pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::collections::HashMap;
   9use std::fmt::Write as _;
  10use std::io::Read;
  11use std::sync::Arc;
  12use std::time::Duration;
  13use telemetry_events::EditPredictionRating;
  14
  15use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
  16
  17use crate::PredictionProvider;
  18use crate::example::{Example, ExamplePrompt};
  19use crate::progress::{InfoStyle, Progress, Step};
  20use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
  21
  22pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
  23pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
  24const SNOWFLAKE_TIMEOUT_CODE: &str = "000630";
  25
  26/// Minimum Zed version for filtering captured examples.
  27/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
  28/// where `zed_version >= 0.224.1`.
  29#[derive(Clone, Copy, Debug)]
  30pub struct MinCaptureVersion {
  31    pub minor: u32,
  32    pub patch: u32,
  33}
  34
  35pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
  36const PARTITION_FETCH_MAX_RETRIES: usize = 3;
  37const PARTITION_FETCH_RETRY_DELAYS: [Duration; PARTITION_FETCH_MAX_RETRIES] = [
  38    Duration::from_millis(500),
  39    Duration::from_secs(1),
  40    Duration::from_secs(2),
  41];
  42
  43/// Parse an input token of the form `captured-after:{timestamp}`.
  44pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  45    input.strip_prefix("captured-after:")
  46}
  47
  48/// Parse an input token of the form `rejected-after:{timestamp}`.
  49pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  50    input.strip_prefix("rejected-after:")
  51}
  52
  53/// Parse an input token of the form `requested-after:{timestamp}`.
  54pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  55    input.strip_prefix("requested-after:")
  56}
  57
  58/// Parse an input token of the form `settled-after:{timestamp}`.
  59pub fn parse_settled_after_input(input: &str) -> Option<&str> {
  60    input.strip_prefix("settled-after:")
  61}
  62
  63/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  64/// or `rated-negative-after:{timestamp}`.
  65/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  66pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  67    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  68        Some((timestamp, Some(EditPredictionRating::Positive)))
  69    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  70        Some((timestamp, Some(EditPredictionRating::Negative)))
  71    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  72        Some((timestamp, None))
  73    } else {
  74        None
  75    }
  76}
  77
  78#[derive(Debug, Clone, Deserialize)]
  79#[serde(rename_all = "camelCase")]
  80pub(crate) struct SnowflakeStatementResponse {
  81    #[serde(default)]
  82    pub(crate) data: Vec<Vec<JsonValue>>,
  83    #[serde(default)]
  84    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
  85    #[serde(default)]
  86    pub(crate) code: Option<String>,
  87    #[serde(default)]
  88    pub(crate) message: Option<String>,
  89    #[serde(default)]
  90    pub(crate) statement_handle: Option<String>,
  91}
  92
  93#[derive(Debug, Clone, Deserialize)]
  94#[serde(rename_all = "camelCase")]
  95pub(crate) struct SnowflakeResultSetMetaData {
  96    #[serde(default, rename = "rowType")]
  97    row_type: Vec<SnowflakeColumnMeta>,
  98    #[serde(default)]
  99    num_rows: Option<i64>,
 100    #[serde(default)]
 101    partition_info: Vec<SnowflakePartitionInfo>,
 102}
 103
 104#[derive(Debug, Clone, Deserialize)]
 105#[serde(rename_all = "camelCase")]
 106struct SnowflakePartitionInfo {}
 107
 108#[derive(Debug, Clone, Deserialize)]
 109struct SnowflakeColumnMeta {
 110    #[serde(default)]
 111    name: String,
 112}
 113
 114async fn run_sql_with_polling(
 115    http_client: Arc<dyn HttpClient>,
 116    base_url: &str,
 117    token: &str,
 118    request: &serde_json::Value,
 119    step_progress: &crate::progress::StepProgress,
 120    background_executor: BackgroundExecutor,
 121) -> Result<SnowflakeStatementResponse> {
 122    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 123
 124    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 125        let statement_handle = response
 126            .statement_handle
 127            .as_ref()
 128            .context("async query response missing statementHandle")?
 129            .clone();
 130
 131        for attempt in 0.. {
 132            step_progress.set_substatus(format!("polling ({attempt})"));
 133
 134            background_executor.timer(POLL_INTERVAL).await;
 135
 136            response = fetch_partition_with_retries(
 137                http_client.clone(),
 138                base_url,
 139                token,
 140                &statement_handle,
 141                0,
 142                background_executor.clone(),
 143            )
 144            .await?;
 145
 146            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 147                break;
 148            }
 149        }
 150    }
 151
 152    Ok(response)
 153}
 154
 155struct SnowflakeConfig {
 156    token: String,
 157    base_url: String,
 158    role: Option<String>,
 159}
 160
 161#[derive(Clone)]
 162struct QueryRetryState {
 163    resume_after: String,
 164    remaining_limit: Option<usize>,
 165    offset: usize,
 166}
 167
 168async fn fetch_examples_with_query<MakeBindings>(
 169    http_client: Arc<dyn HttpClient>,
 170    step_progress: &crate::progress::StepProgress,
 171    background_executor: BackgroundExecutor,
 172    statement: &str,
 173    initial_retry_state: QueryRetryState,
 174    make_bindings: MakeBindings,
 175    required_columns: &[&str],
 176    parse_response: for<'a> fn(
 177        &'a SnowflakeStatementResponse,
 178        &'a HashMap<String, usize>,
 179    ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
 180) -> Result<Vec<Example>>
 181where
 182    MakeBindings: Fn(&QueryRetryState) -> JsonValue,
 183{
 184    let snowflake = SnowflakeConfig {
 185        token: std::env::var("EP_SNOWFLAKE_API_KEY")
 186            .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
 187        base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 188            "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 189        )?,
 190        role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
 191    };
 192
 193    let mut requested_columns = required_columns.to_vec();
 194    if !requested_columns.contains(&"continuation_time") {
 195        requested_columns.push("continuation_time");
 196    }
 197
 198    let mut parsed_examples = Vec::new();
 199    let mut retry_state = initial_retry_state;
 200    let mut retry_count = 0usize;
 201
 202    loop {
 203        let bindings = make_bindings(&retry_state);
 204        let request = json!({
 205            "statement": statement,
 206            "database": "EVENTS",
 207            "schema": "PUBLIC",
 208            "warehouse": "DBT",
 209            "role": snowflake.role.as_deref(),
 210            "bindings": bindings
 211        });
 212
 213        let response = match run_sql_with_polling(
 214            http_client.clone(),
 215            &snowflake.base_url,
 216            &snowflake.token,
 217            &request,
 218            step_progress,
 219            background_executor.clone(),
 220        )
 221        .await
 222        {
 223            Ok(response) => response,
 224            Err(error) => {
 225                if is_snowflake_timeout_error(&error) && !parsed_examples.is_empty() {
 226                    retry_count += 1;
 227                    step_progress.set_substatus(format!(
 228                        "retrying from {} ({retry_count})",
 229                        retry_state.resume_after
 230                    ));
 231                    continue;
 232                }
 233
 234                return Err(error);
 235            }
 236        };
 237
 238        let total_rows = response
 239            .result_set_meta_data
 240            .as_ref()
 241            .and_then(|meta| meta.num_rows)
 242            .unwrap_or(response.data.len() as i64);
 243        let partition_count = response
 244            .result_set_meta_data
 245            .as_ref()
 246            .map(|meta| meta.partition_info.len())
 247            .unwrap_or(1)
 248            .max(1);
 249
 250        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 251        step_progress.set_substatus("parsing");
 252
 253        let column_indices = get_column_indices(&response.result_set_meta_data, &requested_columns);
 254        let mut rows_fetched_this_attempt = 0usize;
 255        let mut timed_out_fetching_partition = false;
 256
 257        parsed_examples.extend(parse_response(&response, &column_indices)?);
 258        rows_fetched_this_attempt += response.data.len();
 259        let mut last_continuation_time_this_attempt =
 260            last_continuation_timestamp_from_response(&response, &column_indices);
 261
 262        if partition_count > 1 {
 263            let statement_handle = response
 264                .statement_handle
 265                .as_ref()
 266                .context("response has multiple partitions but no statementHandle")?;
 267
 268            for partition in 1..partition_count {
 269                step_progress.set_substatus(format!(
 270                    "fetching partition {}/{}",
 271                    partition + 1,
 272                    partition_count
 273                ));
 274
 275                let partition_response = match fetch_partition_with_retries(
 276                    http_client.clone(),
 277                    &snowflake.base_url,
 278                    &snowflake.token,
 279                    statement_handle,
 280                    partition,
 281                    background_executor.clone(),
 282                )
 283                .await
 284                {
 285                    Ok(response) => response,
 286                    Err(error) => {
 287                        if is_snowflake_timeout_error(&error) && rows_fetched_this_attempt > 0 {
 288                            timed_out_fetching_partition = true;
 289                            break;
 290                        }
 291
 292                        return Err(error);
 293                    }
 294                };
 295
 296                parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
 297                rows_fetched_this_attempt += partition_response.data.len();
 298
 299                if let Some(partition_continuation_time) =
 300                    last_continuation_timestamp_from_response(&partition_response, &column_indices)
 301                {
 302                    last_continuation_time_this_attempt = Some(partition_continuation_time);
 303                }
 304            }
 305        }
 306
 307        if rows_fetched_this_attempt == 0 {
 308            step_progress.set_substatus("done");
 309            return Ok(parsed_examples);
 310        }
 311
 312        if let Some(remaining_limit_value) = &mut retry_state.remaining_limit {
 313            *remaining_limit_value =
 314                remaining_limit_value.saturating_sub(rows_fetched_this_attempt);
 315            if *remaining_limit_value == 0 {
 316                step_progress.set_substatus("done");
 317                return Ok(parsed_examples);
 318            }
 319        }
 320
 321        if !timed_out_fetching_partition {
 322            step_progress.set_substatus("done");
 323            return Ok(parsed_examples);
 324        }
 325
 326        let Some(last_continuation_time_this_attempt) = last_continuation_time_this_attempt else {
 327            step_progress.set_substatus("done");
 328            return Ok(parsed_examples);
 329        };
 330
 331        retry_state.resume_after = last_continuation_time_this_attempt;
 332        retry_state.offset = 0;
 333        retry_count += 1;
 334        step_progress.set_substatus(format!(
 335            "retrying from {} ({retry_count})",
 336            retry_state.resume_after
 337        ));
 338    }
 339}
 340
 341pub(crate) async fn fetch_partition(
 342    http_client: Arc<dyn HttpClient>,
 343    base_url: &str,
 344    token: &str,
 345    statement_handle: &str,
 346    partition: usize,
 347) -> Result<SnowflakeStatementResponse> {
 348    let url = format!(
 349        "{}/api/v2/statements/{}?partition={}",
 350        base_url.trim_end_matches('/'),
 351        statement_handle,
 352        partition
 353    );
 354
 355    let http_request = Request::builder()
 356        .method(Method::GET)
 357        .uri(url.as_str())
 358        .header("Authorization", format!("Bearer {token}"))
 359        .header(
 360            "X-Snowflake-Authorization-Token-Type",
 361            "PROGRAMMATIC_ACCESS_TOKEN",
 362        )
 363        .header("Accept", "application/json")
 364        .header("Accept-Encoding", "gzip")
 365        .header("User-Agent", "edit_prediction_cli")
 366        .body(AsyncBody::empty())?;
 367
 368    let response = http_client
 369        .send(http_request)
 370        .await
 371        .context("failed to send partition request to Snowflake SQL API")?;
 372
 373    let status = response.status();
 374    let content_encoding = response
 375        .headers()
 376        .get("content-encoding")
 377        .and_then(|v| v.to_str().ok())
 378        .map(|s| s.to_lowercase());
 379
 380    let body_bytes = {
 381        use futures::AsyncReadExt as _;
 382
 383        let mut body = response.into_body();
 384        let mut bytes = Vec::new();
 385        body.read_to_end(&mut bytes)
 386            .await
 387            .context("failed to read Snowflake SQL API partition response body")?;
 388        bytes
 389    };
 390
 391    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 392        let mut decoder = GzDecoder::new(&body_bytes[..]);
 393        let mut decompressed = Vec::new();
 394        decoder
 395            .read_to_end(&mut decompressed)
 396            .context("failed to decompress gzip response")?;
 397        decompressed
 398    } else {
 399        body_bytes
 400    };
 401
 402    if !status.is_success() && status.as_u16() != 202 {
 403        let body_text = String::from_utf8_lossy(&body_bytes);
 404        anyhow::bail!(
 405            "snowflake sql api partition request http {}: {}",
 406            status.as_u16(),
 407            body_text
 408        );
 409    }
 410
 411    if body_bytes.is_empty() {
 412        anyhow::bail!(
 413            "snowflake sql api partition {} returned empty response body (http {})",
 414            partition,
 415            status.as_u16()
 416        );
 417    }
 418
 419    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 420        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 421        format!(
 422            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 423            partition,
 424            status.as_u16(),
 425            body_preview
 426        )
 427    })
 428}
 429
 430async fn fetch_partition_with_retries(
 431    http_client: Arc<dyn HttpClient>,
 432    base_url: &str,
 433    token: &str,
 434    statement_handle: &str,
 435    partition: usize,
 436    background_executor: BackgroundExecutor,
 437) -> Result<SnowflakeStatementResponse> {
 438    let mut last_error = None;
 439
 440    for retry_attempt in 0..=PARTITION_FETCH_MAX_RETRIES {
 441        match fetch_partition(
 442            http_client.clone(),
 443            base_url,
 444            token,
 445            statement_handle,
 446            partition,
 447        )
 448        .await
 449        {
 450            Ok(response) => return Ok(response),
 451            Err(error) => {
 452                if retry_attempt == PARTITION_FETCH_MAX_RETRIES
 453                    || !is_transient_partition_fetch_error(&error)
 454                {
 455                    return Err(error);
 456                }
 457
 458                last_error = Some(error);
 459                background_executor
 460                    .timer(PARTITION_FETCH_RETRY_DELAYS[retry_attempt])
 461                    .await;
 462            }
 463        }
 464    }
 465
 466    match last_error {
 467        Some(error) => Err(error),
 468        None => anyhow::bail!("partition fetch retry loop exited without a result"),
 469    }
 470}
 471
 472fn is_transient_partition_fetch_error(error: &anyhow::Error) -> bool {
 473    error.chain().any(|cause| {
 474        let message = cause.to_string();
 475        message.contains("failed to read Snowflake SQL API partition response body")
 476            || message.contains("unexpected EOF")
 477            || message.contains("peer closed connection without sending TLS close_notify")
 478    })
 479}
 480
 481pub(crate) async fn run_sql(
 482    http_client: Arc<dyn HttpClient>,
 483    base_url: &str,
 484    token: &str,
 485    request: &serde_json::Value,
 486) -> Result<SnowflakeStatementResponse> {
 487    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 488
 489    let request_body =
 490        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 491
 492    let http_request = Request::builder()
 493        .method(Method::POST)
 494        .uri(url.as_str())
 495        .header("Authorization", format!("Bearer {token}"))
 496        .header(
 497            "X-Snowflake-Authorization-Token-Type",
 498            "PROGRAMMATIC_ACCESS_TOKEN",
 499        )
 500        .header("Content-Type", "application/json")
 501        .header("Accept", "application/json")
 502        .header("User-Agent", "edit_prediction_cli")
 503        .body(AsyncBody::from(request_body.clone()))?;
 504
 505    let response = http_client
 506        .send(http_request)
 507        .await
 508        .context("failed to send request to Snowflake SQL API")?;
 509
 510    let status = response.status();
 511    let body_bytes = {
 512        use futures::AsyncReadExt as _;
 513
 514        let mut body = response.into_body();
 515        let mut bytes = Vec::new();
 516        body.read_to_end(&mut bytes)
 517            .await
 518            .context("failed to read Snowflake SQL API response body")?;
 519        bytes
 520    };
 521
 522    let snowflake_response = serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 523        .context("failed to parse Snowflake SQL API response JSON")?;
 524
 525    if !status.is_success() && status.as_u16() != 202 && !is_timeout_response(&snowflake_response) {
 526        let body_text = String::from_utf8_lossy(&body_bytes);
 527        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 528    }
 529
 530    if is_timeout_response(&snowflake_response) {
 531        anyhow::bail!(
 532            "snowflake sql api timed out code={} message={}",
 533            snowflake_response.code.as_deref().unwrap_or("<no code>"),
 534            snowflake_response
 535                .message
 536                .as_deref()
 537                .unwrap_or("<no message>")
 538        );
 539    }
 540
 541    Ok(snowflake_response)
 542}
 543
 544pub async fn fetch_rejected_examples_after(
 545    http_client: Arc<dyn HttpClient>,
 546    after_timestamps: &[String],
 547    max_rows_per_timestamp: Option<usize>,
 548    offset: usize,
 549    background_executor: BackgroundExecutor,
 550    min_capture_version: Option<MinCaptureVersion>,
 551) -> Result<Vec<Example>> {
 552    if after_timestamps.is_empty() {
 553        return Ok(Vec::new());
 554    }
 555
 556    let progress = Progress::global();
 557
 558    let mut all_examples = Vec::new();
 559
 560    for after_date in after_timestamps.iter() {
 561        let step_progress_name = format!("rejected>{after_date}");
 562        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 563        step_progress.set_substatus("querying");
 564
 565        let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
 566        let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
 567        let min_minor_str_ref = min_minor_str.as_deref();
 568        let min_patch_str_ref = min_patch_str.as_deref();
 569
 570        let statement = indoc! {r#"
 571            SELECT
 572                ep_request_id AS request_id,
 573                device_id AS device_id,
 574                requested_at::string AS continuation_time,
 575                requested_at::string AS time,
 576                input_payload AS input,
 577                prompt AS prompt,
 578                requested_output AS output,
 579                is_ep_shown_before_rejected AS was_shown,
 580                ep_rejected_reason AS reason,
 581                zed_version AS zed_version
 582            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
 583            WHERE ep_outcome LIKE 'Rejected%'
 584                AND is_ep_shown_before_rejected = true
 585                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
 586                AND (? IS NULL OR (
 587                    TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
 588                    OR (
 589                        TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
 590                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
 591                    )
 592                ))
 593            ORDER BY requested_at ASC
 594            LIMIT ?
 595            OFFSET ?
 596        "#};
 597
 598        let examples = fetch_examples_with_query(
 599            http_client.clone(),
 600            &step_progress,
 601            background_executor.clone(),
 602            statement,
 603            QueryRetryState {
 604                resume_after: after_date.clone(),
 605                remaining_limit: max_rows_per_timestamp,
 606                offset,
 607            },
 608            |retry_state| {
 609                json!({
 610                    "1": { "type": "TEXT", "value": retry_state.resume_after },
 611                    "2": { "type": "FIXED", "value": min_minor_str_ref },
 612                    "3": { "type": "FIXED", "value": min_minor_str_ref },
 613                    "4": { "type": "FIXED", "value": min_minor_str_ref },
 614                    "5": { "type": "FIXED", "value": min_patch_str_ref },
 615                    "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
 616                    "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
 617                })
 618            },
 619            &[
 620                "request_id",
 621                "device_id",
 622                "time",
 623                "input",
 624                "prompt",
 625                "output",
 626                "was_shown",
 627                "reason",
 628                "zed_version",
 629            ],
 630            rejected_examples_from_response,
 631        )
 632        .await?;
 633
 634        all_examples.extend(examples);
 635    }
 636
 637    Ok(all_examples)
 638}
 639
 640fn format_limit(limit: Option<usize>) -> String {
 641    return limit.map(|l| l.to_string()).unwrap_or("NULL".to_string());
 642}
 643
 644pub async fn fetch_requested_examples_after(
 645    http_client: Arc<dyn HttpClient>,
 646    after_timestamps: &[String],
 647    max_rows_per_timestamp: Option<usize>,
 648    offset: usize,
 649    background_executor: BackgroundExecutor,
 650    min_capture_version: Option<MinCaptureVersion>,
 651) -> Result<Vec<Example>> {
 652    if after_timestamps.is_empty() {
 653        return Ok(Vec::new());
 654    }
 655
 656    let progress = Progress::global();
 657
 658    let mut all_examples = Vec::new();
 659
 660    for after_date in after_timestamps.iter() {
 661        let step_progress_name = format!("requested>{after_date}");
 662        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 663        step_progress.set_substatus("querying");
 664
 665        let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
 666        let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
 667        let min_minor_str_ref = min_minor_str.as_deref();
 668        let min_patch_str_ref = min_patch_str.as_deref();
 669
 670        let statement = indoc! {r#"
 671            SELECT
 672                ep_request_id AS request_id,
 673                device_id AS device_id,
 674                requested_at::string AS continuation_time,
 675                requested_at::string AS time,
 676                input_payload AS input,
 677                zed_version AS zed_version
 678            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
 679            WHERE requested_at > TRY_TO_TIMESTAMP_NTZ(?)
 680                AND (? IS NULL OR (
 681                    TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
 682                    OR (
 683                        TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
 684                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
 685                    )
 686                ))
 687            ORDER BY requested_at ASC
 688            LIMIT ?
 689            OFFSET ?
 690        "#};
 691
 692        let examples = fetch_examples_with_query(
 693            http_client.clone(),
 694            &step_progress,
 695            background_executor.clone(),
 696            statement,
 697            QueryRetryState {
 698                resume_after: after_date.clone(),
 699                remaining_limit: max_rows_per_timestamp,
 700                offset,
 701            },
 702            |retry_state| {
 703                json!({
 704                    "1": { "type": "TEXT", "value": retry_state.resume_after },
 705                    "2": { "type": "FIXED", "value": min_minor_str_ref },
 706                    "3": { "type": "FIXED", "value": min_minor_str_ref },
 707                    "4": { "type": "FIXED", "value": min_minor_str_ref },
 708                    "5": { "type": "FIXED", "value": min_patch_str_ref },
 709                    "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
 710                    "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
 711                })
 712            },
 713            &["request_id", "device_id", "time", "input", "zed_version"],
 714            requested_examples_from_response,
 715        )
 716        .await?;
 717
 718        all_examples.extend(examples);
 719    }
 720
 721    Ok(all_examples)
 722}
 723
 724pub async fn fetch_captured_examples_after(
 725    http_client: Arc<dyn HttpClient>,
 726    after_timestamps: &[String],
 727    max_rows_per_timestamp: Option<usize>,
 728    offset: usize,
 729    background_executor: BackgroundExecutor,
 730    min_capture_version: Option<MinCaptureVersion>,
 731) -> Result<Vec<Example>> {
 732    if after_timestamps.is_empty() {
 733        return Ok(Vec::new());
 734    }
 735
 736    let progress = Progress::global();
 737
 738    let mut all_examples = Vec::new();
 739
 740    for after_date in after_timestamps.iter() {
 741        let step_progress_name = format!("captured>{after_date}");
 742        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 743        step_progress.set_substatus("querying");
 744
 745        let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
 746        let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
 747        let min_minor_str_ref = min_minor_str.as_deref();
 748        let min_patch_str_ref = min_patch_str.as_deref();
 749
 750        let statement = indoc! {r#"
 751            SELECT
 752                ep_request_id AS request_id,
 753                device_id AS device_id,
 754                requested_at::string AS continuation_time,
 755                requested_at::string AS time,
 756                input_payload AS input,
 757                settled_editable_region AS settled_editable_region,
 758                example_payload AS example,
 759                zed_version AS zed_version
 760            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
 761            WHERE settled_editable_region IS NOT NULL
 762                AND example_payload IS NOT NULL
 763                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
 764                AND (? IS NULL OR (
 765                    TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
 766                    OR (
 767                        TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
 768                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
 769                    )
 770                ))
 771            ORDER BY requested_at ASC
 772            LIMIT ?
 773            OFFSET ?
 774        "#};
 775
 776        let examples = fetch_examples_with_query(
 777            http_client.clone(),
 778            &step_progress,
 779            background_executor.clone(),
 780            statement,
 781            QueryRetryState {
 782                resume_after: after_date.clone(),
 783                remaining_limit: max_rows_per_timestamp,
 784                offset,
 785            },
 786            |retry_state| {
 787                json!({
 788                    "1": { "type": "TEXT", "value": retry_state.resume_after },
 789                    "2": { "type": "FIXED", "value": min_minor_str_ref },
 790                    "3": { "type": "FIXED", "value": min_minor_str_ref },
 791                    "4": { "type": "FIXED", "value": min_minor_str_ref },
 792                    "5": { "type": "FIXED", "value": min_patch_str_ref },
 793                    "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
 794                    "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
 795                })
 796            },
 797            &[
 798                "request_id",
 799                "device_id",
 800                "time",
 801                "input",
 802                "settled_editable_region",
 803                "example",
 804                "zed_version",
 805            ],
 806            captured_examples_from_response,
 807        )
 808        .await?;
 809
 810        all_examples.extend(examples);
 811    }
 812
 813    Ok(all_examples)
 814}
 815
 816pub async fn fetch_settled_examples_after(
 817    http_client: Arc<dyn HttpClient>,
 818    after_timestamps: &[String],
 819    max_rows_per_timestamp: Option<usize>,
 820    offset: usize,
 821    background_executor: BackgroundExecutor,
 822    min_capture_version: Option<MinCaptureVersion>,
 823) -> Result<Vec<Example>> {
 824    if after_timestamps.is_empty() {
 825        return Ok(Vec::new());
 826    }
 827
 828    let progress = Progress::global();
 829
 830    let mut all_examples = Vec::new();
 831
 832    for after_date in after_timestamps.iter() {
 833        let step_progress_name = format!("settled>{after_date}");
 834        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 835        step_progress.set_substatus("querying");
 836
 837        let _ = min_capture_version;
 838
 839        let statement = indoc! {r#"
 840            SELECT
 841                ep_request_id AS request_id,
 842                device_id AS device_id,
 843                requested_at::string AS continuation_time,
 844                requested_at::string AS time,
 845                input_payload AS input,
 846                requested_output AS requested_output,
 847                settled_editable_region AS settled_editable_region,
 848                requested_format AS requested_format,
 849                zed_version AS zed_version
 850            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
 851            WHERE settled_editable_region IS NOT NULL
 852                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
 853            ORDER BY requested_at ASC
 854            LIMIT ?
 855            OFFSET ?
 856        "#};
 857
 858        let examples = fetch_examples_with_query(
 859            http_client.clone(),
 860            &step_progress,
 861            background_executor.clone(),
 862            statement,
 863            QueryRetryState {
 864                resume_after: after_date.clone(),
 865                remaining_limit: max_rows_per_timestamp,
 866                offset,
 867            },
 868            |retry_state| {
 869                json!({
 870                    "1": { "type": "TEXT", "value": retry_state.resume_after },
 871                    "2": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
 872                    "3": { "type": "FIXED", "value": retry_state.offset.to_string() }
 873                })
 874            },
 875            &[
 876                "request_id",
 877                "device_id",
 878                "time",
 879                "input",
 880                "requested_output",
 881                "settled_editable_region",
 882                "requested_format",
 883                "zed_version",
 884            ],
 885            settled_examples_from_response,
 886        )
 887        .await?;
 888
 889        all_examples.extend(examples);
 890    }
 891
 892    Ok(all_examples)
 893}
 894
 895pub async fn fetch_rated_examples_after(
 896    http_client: Arc<dyn HttpClient>,
 897    inputs: &[(String, Option<EditPredictionRating>)],
 898    max_rows_per_timestamp: Option<usize>,
 899    offset: usize,
 900    background_executor: BackgroundExecutor,
 901    _min_capture_version: Option<MinCaptureVersion>,
 902) -> Result<Vec<Example>> {
 903    if inputs.is_empty() {
 904        return Ok(Vec::new());
 905    }
 906
 907    let progress = Progress::global();
 908
 909    let mut all_examples = Vec::new();
 910
 911    for (after_date, rating_filter) in inputs.iter() {
 912        let filter_label = match rating_filter {
 913            None => "",
 914            Some(EditPredictionRating::Positive) => ":positive",
 915            Some(EditPredictionRating::Negative) => ":negative",
 916        };
 917        let step_progress_name = format!("rated{filter_label}>{after_date}");
 918        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 919        step_progress.set_substatus("querying");
 920
 921        let rating_value = rating_filter.as_ref().map(|rating| match rating {
 922            EditPredictionRating::Positive => "Positive",
 923            EditPredictionRating::Negative => "Negative",
 924        });
 925
 926        let statement = indoc! {r#"
 927            SELECT
 928                ep_request_id AS request_id,
 929                rated_inputs AS inputs,
 930                rated_output AS output,
 931                rating AS rating,
 932                feedback AS feedback,
 933                device_id AS device_id,
 934                requested_at::string AS continuation_time,
 935                requested_at::string AS time,
 936                NULL AS experiment_name,
 937                NULL AS environment,
 938                zed_version AS zed_version
 939            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
 940            WHERE rating IS NOT NULL
 941                AND (? IS NULL OR rating = ?)
 942                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
 943                AND rated_inputs IS NOT NULL
 944                AND rated_inputs:cursor_excerpt IS NOT NULL
 945                AND rated_output IS NOT NULL
 946            ORDER BY requested_at ASC
 947            LIMIT ?
 948            OFFSET ?
 949        "#};
 950
 951        let examples = fetch_examples_with_query(
 952            http_client.clone(),
 953            &step_progress,
 954            background_executor.clone(),
 955            statement,
 956            QueryRetryState {
 957                resume_after: after_date.clone(),
 958                remaining_limit: max_rows_per_timestamp,
 959                offset,
 960            },
 961            |retry_state| {
 962                json!({
 963                    "1": { "type": "TEXT", "value": rating_value },
 964                    "2": { "type": "TEXT", "value": rating_value },
 965                    "3": { "type": "TEXT", "value": retry_state.resume_after },
 966                    "4": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
 967                    "5": { "type": "FIXED", "value": retry_state.offset.to_string() }
 968                })
 969            },
 970            &[
 971                "request_id",
 972                "inputs",
 973                "output",
 974                "rating",
 975                "feedback",
 976                "device_id",
 977                "time",
 978                "experiment_name",
 979                "environment",
 980                "zed_version",
 981            ],
 982            rated_examples_from_response,
 983        )
 984        .await?;
 985
 986        all_examples.extend(examples);
 987    }
 988
 989    Ok(all_examples)
 990}
 991
 992fn rated_examples_from_response<'a>(
 993    response: &'a SnowflakeStatementResponse,
 994    column_indices: &'a std::collections::HashMap<String, usize>,
 995) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
 996    if let Some(code) = &response.code {
 997        if code != SNOWFLAKE_SUCCESS_CODE {
 998            anyhow::bail!(
 999                "snowflake sql api returned error code={code} message={}",
1000                response.message.as_deref().unwrap_or("<no message>")
1001            );
1002        }
1003    }
1004
1005    let iter = response
1006        .data
1007        .iter()
1008        .enumerate()
1009        .filter_map(move |(row_index, data_row)| {
1010            let get_string = |name: &str| -> Option<String> {
1011                let index = column_indices.get(name).copied()?;
1012                match data_row.get(index)? {
1013                    JsonValue::String(s) => Some(s.clone()),
1014                    JsonValue::Null => None,
1015                    other => Some(other.to_string()),
1016                }
1017            };
1018
1019            let get_json = |name: &str| -> Option<JsonValue> {
1020                let index = column_indices.get(name).copied()?;
1021                let value = data_row.get(index)?;
1022                if value.is_null() {
1023                    return None;
1024                }
1025                match value {
1026                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1027                    other => Some(other.clone()),
1028                }
1029            };
1030
1031            let request_id = get_string("request_id");
1032            let inputs_json = get_json("inputs");
1033            let inputs: Option<ZetaPromptInput> = match &inputs_json {
1034                Some(v) => match serde_json::from_value(v.clone()) {
1035                    Ok(parsed) => Some(parsed),
1036                    Err(e) => {
1037                        log::warn!(
1038                            "skipping row {row_index}: failed to parse inputs - {e}",
1039                        );
1040                        return None;
1041                    }
1042                },
1043                None => None,
1044            };
1045            let output = get_string("output");
1046            let rating = get_string("rating");
1047            let feedback = get_string("feedback").unwrap_or_default();
1048            let device_id = get_string("device_id");
1049            let time = get_string("time");
1050            let experiment_name = get_string("experiment_name");
1051            let environment = get_string("environment");
1052            let zed_version = get_string("zed_version");
1053
1054            match (inputs, output.clone(), rating.clone(), time.clone()) {
1055                (Some(inputs), Some(output), Some(rating), Some(time)) => {
1056                    Some(build_rated_example(
1057                        request_id,
1058                        device_id.unwrap_or_default(),
1059                        time,
1060                        inputs,
1061                        output,
1062                        rating,
1063                        feedback,
1064                        experiment_name,
1065                        environment,
1066                        zed_version,
1067                    ))
1068                }
1069                _ => {
1070                    log::warn!(
1071                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} time={:?}",
1072                        inputs_json.is_some(),
1073                        output.is_some(),
1074                        rating.is_some(),
1075                        time.is_some(),
1076                    );
1077                    None
1078                }
1079            }
1080        });
1081
1082    Ok(Box::new(iter))
1083}
1084
1085fn build_rated_example(
1086    request_id: Option<String>,
1087    device_id: String,
1088    time: String,
1089    input: ZetaPromptInput,
1090    output: String,
1091    rating: String,
1092    feedback: String,
1093    experiment_name: Option<String>,
1094    environment: Option<String>,
1095    zed_version: Option<String>,
1096) -> Example {
1097    let parsed_rating = if rating == "Positive" {
1098        EditPredictionRating::Positive
1099    } else {
1100        EditPredictionRating::Negative
1101    };
1102    let is_positive = parsed_rating == EditPredictionRating::Positive;
1103    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1104
1105    let mut tags = Vec::with_capacity(3);
1106    tags.push(if is_positive {
1107        "rated:positive".to_string()
1108    } else {
1109        "rated:negative".to_string()
1110    });
1111    if let Some(experiment) = experiment_name {
1112        tags.push(format!("experiment:{experiment}"));
1113    }
1114    if let Some(env) = environment {
1115        tags.push(format!("environment:{env}"));
1116    }
1117
1118    let mut example =
1119        build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1120
1121    example.spec.rating = Some(parsed_rating);
1122
1123    if !feedback.is_empty() {
1124        example
1125            .spec
1126            .human_feedback
1127            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1128    }
1129
1130    if is_positive {
1131        example.spec.expected_patches = vec![output];
1132    } else {
1133        example.spec.rejected_patch = Some(output);
1134    }
1135
1136    example
1137}
1138
1139fn requested_examples_from_response<'a>(
1140    response: &'a SnowflakeStatementResponse,
1141    column_indices: &'a std::collections::HashMap<String, usize>,
1142) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1143    if let Some(code) = &response.code {
1144        if code != SNOWFLAKE_SUCCESS_CODE {
1145            anyhow::bail!(
1146                "snowflake sql api returned error code={code} message={}",
1147                response.message.as_deref().unwrap_or("<no message>")
1148            );
1149        }
1150    }
1151
1152    let iter = response
1153        .data
1154        .iter()
1155        .enumerate()
1156        .filter_map(move |(row_index, data_row)| {
1157            let get_string = |name: &str| -> Option<String> {
1158                let index = column_indices.get(name).copied()?;
1159                match data_row.get(index)? {
1160                    JsonValue::String(s) => Some(s.clone()),
1161                    JsonValue::Null => None,
1162                    other => Some(other.to_string()),
1163                }
1164            };
1165
1166            let get_json = |name: &str| -> Option<JsonValue> {
1167                let index = column_indices.get(name).copied()?;
1168                let value = data_row.get(index)?;
1169                if value.is_null() {
1170                    return None;
1171                }
1172                match value {
1173                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1174                    other => Some(other.clone()),
1175                }
1176            };
1177
1178            let request_id_str = get_string("request_id");
1179            let device_id = get_string("device_id");
1180            let time = get_string("time");
1181            let input_json = get_json("input");
1182            let input: Option<ZetaPromptInput> =
1183                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1184            let zed_version = get_string("zed_version");
1185
1186            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1187                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1188                    Some(build_example_from_snowflake(
1189                        request_id,
1190                        device_id,
1191                        time,
1192                        input,
1193                        vec!["requested".to_string()],
1194                        None,
1195                        zed_version,
1196                    ))
1197                }
1198                _ => {
1199                    log::warn!(
1200                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1201                        request_id_str.is_some(),
1202                        device_id.is_some(),
1203                        time.is_some(),
1204                        input_json.is_some(),
1205                    );
1206                    None
1207                }
1208            }
1209        });
1210
1211    Ok(Box::new(iter))
1212}
1213
1214fn settled_examples_from_response<'a>(
1215    response: &'a SnowflakeStatementResponse,
1216    column_indices: &'a std::collections::HashMap<String, usize>,
1217) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1218    if let Some(code) = &response.code {
1219        if code != SNOWFLAKE_SUCCESS_CODE {
1220            anyhow::bail!(
1221                "snowflake sql api returned error code={code} message={}",
1222                response.message.as_deref().unwrap_or("<no message>")
1223            );
1224        }
1225    }
1226
1227    let iter = response
1228        .data
1229        .iter()
1230        .enumerate()
1231        .filter_map(move |(row_index, data_row)| {
1232            let get_value = |name: &str| -> Option<JsonValue> {
1233                let index = column_indices.get(name).copied()?;
1234                let value = data_row.get(index)?;
1235                if value.is_null() {
1236                    None
1237                } else {
1238                    Some(value.clone())
1239                }
1240            };
1241
1242            let get_string = |name: &str| -> Option<String> {
1243                match get_value(name)? {
1244                    JsonValue::String(s) => Some(s),
1245                    other => Some(other.to_string()),
1246                }
1247            };
1248
1249            let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1250                let value = raw?;
1251                match value {
1252                    JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1253                    other => Some(other.clone()),
1254                }
1255            };
1256
1257            let request_id_str = get_string("request_id");
1258            let device_id = get_string("device_id");
1259            let time = get_string("time");
1260            let input_raw = get_value("input");
1261            let input_json = parse_json_value(input_raw.as_ref());
1262            let input: Option<ZetaPromptInput> = input_json
1263                .as_ref()
1264                .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1265            let requested_output = get_string("requested_output");
1266            let settled_editable_region = get_string("settled_editable_region");
1267            let requested_format =
1268                get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1269            let zed_version = get_string("zed_version");
1270
1271            match (
1272                request_id_str.clone(),
1273                device_id.clone(),
1274                time.clone(),
1275                input.clone(),
1276                requested_output.clone(),
1277                settled_editable_region.clone(),
1278                requested_format,
1279            ) {
1280                (
1281                    Some(request_id),
1282                    Some(device_id),
1283                    Some(time),
1284                    Some(input),
1285                    Some(requested_output),
1286                    Some(settled_editable_region),
1287                    Some(requested_format),
1288                ) => Some(build_settled_example(
1289                    request_id,
1290                    device_id,
1291                    time,
1292                    input,
1293                    requested_output,
1294                    settled_editable_region,
1295                    requested_format,
1296                    zed_version,
1297                )),
1298                _ => {
1299                    let mut missing_fields = Vec::new();
1300
1301                    if request_id_str.is_none() {
1302                        missing_fields.push("request_id");
1303                    }
1304                    if device_id.is_none() {
1305                        missing_fields.push("device_id");
1306                    }
1307                    if time.is_none() {
1308                        missing_fields.push("time");
1309                    }
1310                    if input_raw.is_none() || input_json.is_none() || input.is_none() {
1311                        missing_fields.push("input");
1312                    }
1313                    if requested_output.is_none() {
1314                        missing_fields.push("requested_output");
1315                    }
1316                    if settled_editable_region.is_none() {
1317                        missing_fields.push("settled_editable_region");
1318                    }
1319                    if requested_format.is_none() {
1320                        missing_fields.push("requested_format");
1321                    }
1322
1323                    log::warn!(
1324                        "skipping settled row {row_index}: [{}]",
1325                        missing_fields.join(", "),
1326                    );
1327                    None
1328                }
1329            }
1330        });
1331
1332    Ok(Box::new(iter))
1333}
1334
1335fn captured_examples_from_response<'a>(
1336    response: &'a SnowflakeStatementResponse,
1337    column_indices: &'a std::collections::HashMap<String, usize>,
1338) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1339    if let Some(code) = &response.code {
1340        if code != SNOWFLAKE_SUCCESS_CODE {
1341            anyhow::bail!(
1342                "snowflake sql api returned error code={code} message={}",
1343                response.message.as_deref().unwrap_or("<no message>")
1344            );
1345        }
1346    }
1347
1348    let iter = response
1349        .data
1350        .iter()
1351        .enumerate()
1352        .filter_map(move |(row_index, data_row)| {
1353            let get_value = |name: &str| -> Option<JsonValue> {
1354                let index = column_indices.get(name).copied()?;
1355                let value = data_row.get(index)?;
1356                if value.is_null() {
1357                    None
1358                } else {
1359                    Some(value.clone())
1360                }
1361            };
1362
1363            let get_string = |name: &str| -> Option<String> {
1364                match get_value(name)? {
1365                    JsonValue::String(s) => Some(s),
1366                    other => Some(other.to_string()),
1367                }
1368            };
1369
1370            let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1371                let value = raw?;
1372                match value {
1373                    JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1374                    other => Some(other.clone()),
1375                }
1376            };
1377
1378            let request_id = get_string("request_id");
1379            let device_id = get_string("device_id");
1380            let time = get_string("time");
1381            let input_raw = get_value("input");
1382            let input_json = parse_json_value(input_raw.as_ref());
1383            let input: Option<ZetaPromptInput> = input_json
1384                .as_ref()
1385                .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1386            let example_raw = get_value("example");
1387            let example_json = parse_json_value(example_raw.as_ref());
1388            let example_spec: Option<ExampleSpec> = example_json.as_ref().and_then(|parsed| {
1389                serde_json::from_value(parsed.clone())
1390                    .or_else(|_| {
1391                        parsed
1392                            .as_str()
1393                            .and_then(|markdown| ExampleSpec::from_markdown(markdown).ok())
1394                            .ok_or_else(|| {
1395                                serde_json::Error::io(std::io::Error::other("not markdown"))
1396                            })
1397                    })
1398                    .ok()
1399            });
1400            let has_example_spec = example_spec.is_some();
1401            let settled_editable_region = get_string("settled_editable_region");
1402            let zed_version = get_string("zed_version");
1403
1404            match (
1405                request_id.clone(),
1406                device_id.clone(),
1407                time.clone(),
1408                input.clone(),
1409                example_spec,
1410                settled_editable_region.clone(),
1411            ) {
1412                (
1413                    Some(request_id),
1414                    Some(device_id),
1415                    Some(time),
1416                    Some(input),
1417                    Some(example_spec),
1418                    Some(settled_editable_region),
1419                ) => Some(build_captured_example(
1420                    request_id,
1421                    device_id,
1422                    time,
1423                    input,
1424                    example_spec,
1425                    settled_editable_region,
1426                    zed_version,
1427                )),
1428                _ => {
1429                    let mut missing_fields = Vec::new();
1430
1431                    if request_id.is_none() {
1432                        missing_fields.push("request_id");
1433                    }
1434                    if device_id.is_none() {
1435                        missing_fields.push("device_id");
1436                    }
1437                    if time.is_none() {
1438                        missing_fields.push("time");
1439                    }
1440                    if input_raw.is_none() || input_json.is_none() || input.is_none() {
1441                        missing_fields.push("input");
1442                    }
1443                    if example_raw.is_none() || !has_example_spec {
1444                        missing_fields.push("example");
1445                    }
1446                    if settled_editable_region.is_none() {
1447                        missing_fields.push("settled_editable_region");
1448                    }
1449
1450                    log::warn!(
1451                        "skipping captured row {row_index}: [{}]",
1452                        missing_fields.join(", "),
1453                    );
1454                    None
1455                }
1456            }
1457        });
1458
1459    Ok(Box::new(iter))
1460}
1461
1462fn build_settled_example(
1463    request_id: String,
1464    device_id: String,
1465    time: String,
1466    input: ZetaPromptInput,
1467    requested_output: String,
1468    settled_editable_region: String,
1469    requested_format: ZetaFormat,
1470    zed_version: Option<String>,
1471) -> Example {
1472    let requested_editable_range =
1473        excerpt_range_for_format(requested_format, &input.excerpt_ranges).0;
1474
1475    let base_cursor_excerpt = input.cursor_excerpt.to_string();
1476
1477    let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1478        && requested_editable_range.end <= base_cursor_excerpt.len();
1479    let mut example = build_example_from_snowflake(
1480        request_id.clone(),
1481        device_id,
1482        time,
1483        input,
1484        vec!["settled".to_string()],
1485        None,
1486        zed_version,
1487    );
1488
1489    if !requested_range_is_valid {
1490        log::warn!(
1491            "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1492            request_id,
1493            requested_editable_range,
1494            base_cursor_excerpt.len(),
1495        );
1496        return example;
1497    }
1498
1499    let settled_replacement = settled_editable_region.as_str();
1500    let rejected_patch = build_output_patch(
1501        &example.spec.cursor_path,
1502        &base_cursor_excerpt,
1503        &requested_editable_range,
1504        &requested_output,
1505    );
1506    let expected_patch = build_output_patch(
1507        &example.spec.cursor_path,
1508        &base_cursor_excerpt,
1509        &requested_editable_range,
1510        settled_replacement,
1511    );
1512
1513    example.spec.expected_patches = vec![expected_patch];
1514    example.spec.rejected_patch = Some(rejected_patch);
1515    example
1516}
1517
1518fn build_captured_example(
1519    request_id: String,
1520    device_id: String,
1521    time: String,
1522    input: ZetaPromptInput,
1523    mut example_spec: ExampleSpec,
1524    settled_editable_region: String,
1525    zed_version: Option<String>,
1526) -> Example {
1527    let expected_patch = build_output_patch(
1528        &input.cursor_path,
1529        input.cursor_excerpt.as_ref(),
1530        &input.excerpt_ranges.editable_350,
1531        settled_editable_region.as_str(),
1532    );
1533
1534    example_spec.expected_patches = vec![expected_patch];
1535    example_spec.telemetry = Some(TelemetrySource {
1536        request_id,
1537        device_id,
1538        time,
1539        rejection_reason: String::new(),
1540        was_shown: false,
1541    });
1542
1543    Example {
1544        spec: example_spec,
1545        zed_version,
1546        prompt_inputs: Some(input),
1547        prompt: None,
1548        predictions: Vec::new(),
1549        score: Vec::new(),
1550        qa: Vec::new(),
1551        state: None,
1552    }
1553}
1554
1555fn rejected_examples_from_response<'a>(
1556    response: &'a SnowflakeStatementResponse,
1557    column_indices: &'a std::collections::HashMap<String, usize>,
1558) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1559    if let Some(code) = &response.code {
1560        if code != SNOWFLAKE_SUCCESS_CODE {
1561            anyhow::bail!(
1562                "snowflake sql api returned error code={code} message={}",
1563                response.message.as_deref().unwrap_or("<no message>")
1564            );
1565        }
1566    }
1567
1568    let iter = response
1569        .data
1570        .iter()
1571        .enumerate()
1572        .filter_map(move |(row_index, data_row)| {
1573            let get_string = |name: &str| -> Option<String> {
1574                let index = column_indices.get(name).copied()?;
1575                match data_row.get(index)? {
1576                    JsonValue::String(s) => Some(s.clone()),
1577                    JsonValue::Null => None,
1578                    other => Some(other.to_string()),
1579                }
1580            };
1581
1582            let get_json = |name: &str| -> Option<JsonValue> {
1583                let index = column_indices.get(name).copied()?;
1584                let value = data_row.get(index)?;
1585                if value.is_null() {
1586                    return None;
1587                }
1588                match value {
1589                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1590                    other => Some(other.clone()),
1591                }
1592            };
1593
1594            let get_bool = |name: &str| -> Option<bool> {
1595                let index = column_indices.get(name).copied()?;
1596                match data_row.get(index)? {
1597                    JsonValue::Bool(b) => Some(*b),
1598                    JsonValue::String(s) => s.parse().ok(),
1599                    _ => None,
1600                }
1601            };
1602
1603            let request_id_str = get_string("request_id");
1604            let device_id = get_string("device_id");
1605            let time = get_string("time");
1606            let input_json = get_json("input");
1607            let input: Option<ZetaPromptInput> =
1608                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1609            let prompt = get_string("prompt");
1610            let output = get_string("output");
1611            let was_shown = get_bool("was_shown");
1612            let reason = get_string("reason");
1613            let zed_version = get_string("zed_version");
1614
1615            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1616                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1617                    Some(build_rejected_example(
1618                        request_id,
1619                        device_id,
1620                        time,
1621                        input,
1622                        prompt,
1623                        output,
1624                        was_shown,
1625                        reason,
1626                        zed_version,
1627                    ))
1628                }
1629                _ => {
1630                    log::warn!(
1631                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1632                        request_id_str.is_some(),
1633                        device_id.is_some(),
1634                        time.is_some(),
1635                        input_json.is_some(),
1636                        output.is_some(),
1637                        was_shown.is_some(),
1638                        reason.is_some()
1639                    );
1640                    None
1641                }
1642            }
1643        });
1644
1645    Ok(Box::new(iter))
1646}
1647
1648fn build_rejected_example(
1649    request_id: String,
1650    device_id: String,
1651    time: String,
1652    input: ZetaPromptInput,
1653    prompt: Option<String>,
1654    output: String,
1655    was_shown: bool,
1656    reason: String,
1657    zed_version: Option<String>,
1658) -> Example {
1659    let rejected_patch = build_output_patch(
1660        &input.cursor_path,
1661        input.cursor_excerpt.as_ref(),
1662        &input.excerpt_ranges.editable_350,
1663        &output,
1664    );
1665    let mut example = build_example_from_snowflake(
1666        request_id,
1667        device_id,
1668        time,
1669        input,
1670        vec![format!("rejection:{}", reason.to_lowercase())],
1671        Some(RejectionInfo { reason, was_shown }),
1672        zed_version,
1673    );
1674    example.spec.rejected_patch = Some(rejected_patch);
1675    example.prompt = prompt.map(|prompt| ExamplePrompt {
1676        input: prompt,
1677        expected_output: String::new(),
1678        rejected_output: Some(output),
1679        prefill: None,
1680        provider: PredictionProvider::default(),
1681    });
1682    example
1683}
1684
/// Rejection details carried alongside a telemetry row when building an example.
struct RejectionInfo {
    /// Value of the row's `was_shown` flag.
    was_shown: bool,
    /// Rejection reason text taken from the row's `reason` column.
    reason: String,
}
1689
1690fn build_example_from_snowflake(
1691    request_id: String,
1692    device_id: String,
1693    time: String,
1694    input: ZetaPromptInput,
1695    tags: Vec<String>,
1696    rejection: Option<RejectionInfo>,
1697    zed_version: Option<String>,
1698) -> Example {
1699    let cursor_excerpt = input.cursor_excerpt.as_ref();
1700    let cursor_offset = input.cursor_offset_in_excerpt;
1701
1702    let mut edit_history = String::new();
1703    for event in &input.events {
1704        zeta_prompt::write_event(&mut edit_history, event);
1705        edit_history.push('\n');
1706    }
1707
1708    let (rejection_reason, was_shown) = match &rejection {
1709        Some(r) => (r.reason.clone(), r.was_shown),
1710        None => (String::new(), false),
1711    };
1712
1713    let spec = ExampleSpec {
1714        name: request_id.clone(),
1715        repository_url: String::new(),
1716        revision: String::new(),
1717        tags,
1718        reasoning: None,
1719        uncommitted_diff: String::new(),
1720        cursor_path: input.cursor_path.clone(),
1721        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1722        edit_history,
1723        expected_patches: Vec::new(),
1724        rejected_patch: None,
1725        telemetry: Some(TelemetrySource {
1726            request_id,
1727            device_id,
1728            time,
1729            rejection_reason,
1730            was_shown,
1731        }),
1732        human_feedback: Vec::new(),
1733        rating: None,
1734    };
1735
1736    Example {
1737        spec,
1738        zed_version,
1739        prompt_inputs: Some(input),
1740        prompt: None,
1741        predictions: Vec::new(),
1742        score: Vec::new(),
1743        qa: Vec::new(),
1744        state: None,
1745    }
1746}
1747
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte offset `cursor_offset`.
///
/// The offset comes from telemetry and is not trusted:
/// - Offsets past the end of the excerpt are clamped to its length, so the marker is appended.
/// - Offsets that land inside a multi-byte UTF-8 character are moved back to the start of that
///   character. Slicing a `str` at a non-boundary would panic, so this returns a best-effort
///   marker instead.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut boundary = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates without underflow
    // (and runs at most three times for valid UTF-8).
    while !excerpt.is_char_boundary(boundary) {
        boundary -= 1;
    }
    let (before, after) = excerpt.split_at(boundary);
    format!("{before}[CURSOR_POSITION]{after}")
}
1753
1754fn build_output_patch(
1755    cursor_path: &std::path::Path,
1756    cursor_excerpt: &str,
1757    editable_range: &std::ops::Range<usize>,
1758    model_output: &str,
1759) -> String {
1760    let old_text = &cursor_excerpt[editable_range.clone()];
1761
1762    let editable_start_row = cursor_excerpt[..editable_range.start]
1763        .chars()
1764        .filter(|&c| c == '\n')
1765        .count() as u32;
1766
1767    let diff_body = language::unified_diff_with_offsets(
1768        old_text,
1769        model_output,
1770        editable_start_row,
1771        editable_start_row,
1772    );
1773
1774    let mut patch = String::new();
1775    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1776    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1777    patch.push_str(&diff_body);
1778    patch
1779}
1780
1781fn is_timeout_response(response: &SnowflakeStatementResponse) -> bool {
1782    response.code.as_deref() == Some(SNOWFLAKE_TIMEOUT_CODE)
1783        && response
1784            .message
1785            .as_deref()
1786            .map(|message| message.to_ascii_lowercase().contains("timeout"))
1787            .unwrap_or(false)
1788}
1789
1790fn is_snowflake_timeout_error(error: &anyhow::Error) -> bool {
1791    error
1792        .chain()
1793        .any(|cause| cause.to_string().contains(SNOWFLAKE_TIMEOUT_CODE))
1794}
1795
1796fn last_continuation_timestamp_from_response(
1797    response: &SnowflakeStatementResponse,
1798    column_indices: &HashMap<String, usize>,
1799) -> Option<String> {
1800    let continuation_time_index = column_indices.get("continuation_time").copied()?;
1801    response
1802        .data
1803        .iter()
1804        .rev()
1805        .find_map(|row| match row.get(continuation_time_index)? {
1806            JsonValue::String(value) => Some(value.clone()),
1807            JsonValue::Null => None,
1808            other => Some(other.to_string()),
1809        })
1810}
1811
1812pub(crate) fn get_column_indices(
1813    meta: &Option<SnowflakeResultSetMetaData>,
1814    names: &[&str],
1815) -> HashMap<String, usize> {
1816    let mut indices = HashMap::new();
1817    if let Some(meta) = meta {
1818        for (index, col) in meta.row_type.iter().enumerate() {
1819            for &name in names {
1820                if col.name.eq_ignore_ascii_case(name) {
1821                    indices.insert(name.to_string(), index);
1822                }
1823            }
1824        }
1825    }
1826    indices
1827}