pull_examples.rs

  1use anyhow::{Context as _, Result};
  2use flate2::read::GzDecoder;
  3use gpui::BackgroundExecutor;
  4use http_client::{AsyncBody, HttpClient, Method, Request};
  5use indoc::indoc;
  6use serde::Deserialize;
  7use serde_json::{Value as JsonValue, json};
  8use std::io::Read;
  9use std::sync::Arc;
 10use std::time::Duration;
 11
 12use zeta_prompt::ZetaPromptInput;
 13
 14use crate::example::Example;
 15use crate::progress::{InfoStyle, Progress, Step};
 16use edit_prediction::example_spec::{
 17    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
 18    TelemetrySource,
 19};
 20use std::fmt::Write as _;
 21
/// Response `code` accepted by `examples_from_response` and
/// `rejected_examples_from_response`; any other code present in a response is
/// reported as an error.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` that `run_sql_with_polling` treats as "statement still running"
/// and keeps polling on.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` value bound into the captured-examples query.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` value for the request side (`req`) of the rejected-examples join.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` value for the rejection side (`rej`) of the rejected-examples join.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";

/// Value sent as the `timeout` field of every statement request.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay between successive status polls in `run_sql_with_polling`.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of status polls before `run_sql_with_polling` gives up.
const MAX_POLL_ATTEMPTS: usize = 120;
 31
 32/// Parse an input token of the form `captured-after:{timestamp}`.
 33pub fn parse_captured_after_input(input: &str) -> Option<&str> {
 34    input.strip_prefix("captured-after:")
 35}
 36
 37/// Parse an input token of the form `rejected-after:{timestamp}`.
 38pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
 39    input.strip_prefix("rejected-after:")
 40}
 41
 42pub async fn fetch_captured_examples_after(
 43    http_client: Arc<dyn HttpClient>,
 44    after_timestamps: &[String],
 45    max_rows_per_timestamp: usize,
 46    background_executor: BackgroundExecutor,
 47) -> Result<Vec<Example>> {
 48    if after_timestamps.is_empty() {
 49        return Ok(Vec::new());
 50    }
 51
 52    let progress = Progress::global();
 53
 54    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 55        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 56    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 57        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 58    )?;
 59    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 60
 61    let mut all_examples = Vec::new();
 62
 63    for after_date in after_timestamps.iter() {
 64        let step_progress_name = format!(">{after_date}");
 65        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 66        step_progress.set_substatus("querying");
 67
 68        let statement = indoc! {r#"
 69            SELECT
 70                event_properties:example AS example
 71            FROM events
 72            WHERE event_type = ?
 73                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 74            ORDER BY time ASC
 75            LIMIT ?
 76        "#};
 77
 78        let request = json!({
 79            "statement": statement,
 80            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 81            "database": "EVENTS",
 82            "schema": "PUBLIC",
 83            "warehouse": "DBT",
 84            "role": role,
 85            "bindings": {
 86                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 87                "2": { "type": "TEXT", "value": after_date },
 88                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 89            }
 90        });
 91
 92        let response = run_sql_with_polling(
 93            http_client.clone(),
 94            &base_url,
 95            &token,
 96            &request,
 97            &step_progress,
 98            background_executor.clone(),
 99        )
100        .await?;
101
102        let total_rows = response
103            .result_set_meta_data
104            .as_ref()
105            .and_then(|m| m.num_rows)
106            .unwrap_or(response.data.len() as i64);
107
108        let num_partitions = response
109            .result_set_meta_data
110            .as_ref()
111            .map(|m| m.partition_info.len())
112            .unwrap_or(1)
113            .max(1);
114
115        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
116        step_progress.set_substatus("parsing");
117
118        all_examples.extend(examples_from_response(&response)?);
119
120        if num_partitions > 1 {
121            let statement_handle = response
122                .statement_handle
123                .as_ref()
124                .context("response has multiple partitions but no statementHandle")?;
125
126            for partition in 1..num_partitions {
127                step_progress.set_substatus(format!(
128                    "fetching partition {}/{}",
129                    partition + 1,
130                    num_partitions
131                ));
132
133                let partition_response = fetch_partition(
134                    http_client.clone(),
135                    &base_url,
136                    &token,
137                    statement_handle,
138                    partition,
139                )
140                .await?;
141
142                all_examples.extend(examples_from_response(&partition_response)?);
143            }
144        }
145
146        step_progress.set_substatus("done");
147    }
148
149    Ok(all_examples)
150}
151
/// The parts of a Snowflake SQL API statement response body that this module reads.
///
/// Every field is marked `#[serde(default)]`, so a body that omits any of them still
/// deserializes; JSON keys are expected in camelCase.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON cell values indexed by column position.
    #[serde(default)]
    data: Vec<Vec<JsonValue>>,
    /// Column, row-count and partition metadata, when the response includes it.
    #[serde(default)]
    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    code: Option<String>,
    /// Human-readable message, included in the error raised for a non-success code.
    #[serde(default)]
    message: Option<String>,
    /// Handle used to poll the statement and to fetch additional partitions.
    #[serde(default)]
    statement_handle: Option<String>,
}
166
/// Metadata describing a statement's result set (`resultSetMetaData` in the response).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeResultSetMetaData {
    /// Column descriptors; a column's position here is used as its index into each
    /// data row.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported by the server, shown in the progress info when present.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
177
/// Placeholder for one entry of `partitionInfo`.
///
/// No fields are declared or read: callers only count the entries to learn how many
/// partitions the result set has.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
181
/// Descriptor of a single result column (one entry of `rowType`).
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; compared case-insensitively when columns are looked up by name.
    /// Defaults to an empty string when the key is absent.
    #[serde(default)]
    name: String,
}
187
188fn examples_from_response(
189    response: &SnowflakeStatementResponse,
190) -> Result<impl Iterator<Item = Example> + '_> {
191    if let Some(code) = &response.code {
192        if code != SNOWFLAKE_SUCCESS_CODE {
193            anyhow::bail!(
194                "snowflake sql api returned error code={code} message={}",
195                response.message.as_deref().unwrap_or("<no message>")
196            );
197        }
198    }
199
200    let example_index = response
201        .result_set_meta_data
202        .as_ref()
203        .and_then(|m| {
204            m.row_type.iter().enumerate().find_map(|(index, col)| {
205                if col.name.eq_ignore_ascii_case("example") {
206                    Some(index)
207                } else {
208                    None
209                }
210            })
211        })
212        .unwrap_or(0);
213
214    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
215        let Some(example_value) = data_row.get(example_index) else {
216            return None;
217        };
218        if example_value.is_null() {
219            return None;
220        }
221
222        let parse_result = match example_value {
223            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
224            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
225        };
226
227        match parse_result {
228            Ok(spec) => Some(Example {
229                spec,
230                prompt_inputs: None,
231                prompt: None,
232                predictions: Vec::new(),
233                score: Vec::new(),
234                state: None,
235            }),
236            Err(error) => {
237                let raw_json = serde_json::to_string_pretty(example_value)
238                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
239                log::error!(
240                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
241                );
242                None
243            }
244        }
245    });
246
247    Ok(iter)
248}
249
250async fn run_sql_with_polling(
251    http_client: Arc<dyn HttpClient>,
252    base_url: &str,
253    token: &str,
254    request: &serde_json::Value,
255    step_progress: &crate::progress::StepProgress,
256    background_executor: BackgroundExecutor,
257) -> Result<SnowflakeStatementResponse> {
258    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
259
260    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
261        let statement_handle = response
262            .statement_handle
263            .as_ref()
264            .context("async query response missing statementHandle")?
265            .clone();
266
267        for attempt in 1..=MAX_POLL_ATTEMPTS {
268            step_progress.set_substatus(format!("polling ({attempt})"));
269
270            background_executor.timer(POLL_INTERVAL).await;
271
272            response =
273                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
274
275            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
276                break;
277            }
278        }
279
280        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
281            anyhow::bail!(
282                "query still running after {} poll attempts ({} seconds)",
283                MAX_POLL_ATTEMPTS,
284                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
285            );
286        }
287    }
288
289    Ok(response)
290}
291
292async fn fetch_partition(
293    http_client: Arc<dyn HttpClient>,
294    base_url: &str,
295    token: &str,
296    statement_handle: &str,
297    partition: usize,
298) -> Result<SnowflakeStatementResponse> {
299    let url = format!(
300        "{}/api/v2/statements/{}?partition={}",
301        base_url.trim_end_matches('/'),
302        statement_handle,
303        partition
304    );
305
306    let http_request = Request::builder()
307        .method(Method::GET)
308        .uri(url.as_str())
309        .header("Authorization", format!("Bearer {token}"))
310        .header(
311            "X-Snowflake-Authorization-Token-Type",
312            "PROGRAMMATIC_ACCESS_TOKEN",
313        )
314        .header("Accept", "application/json")
315        .header("Accept-Encoding", "gzip")
316        .header("User-Agent", "edit_prediction_cli")
317        .body(AsyncBody::empty())?;
318
319    let response = http_client
320        .send(http_request)
321        .await
322        .context("failed to send partition request to Snowflake SQL API")?;
323
324    let status = response.status();
325    let content_encoding = response
326        .headers()
327        .get("content-encoding")
328        .and_then(|v| v.to_str().ok())
329        .map(|s| s.to_lowercase());
330
331    let body_bytes = {
332        use futures::AsyncReadExt as _;
333
334        let mut body = response.into_body();
335        let mut bytes = Vec::new();
336        body.read_to_end(&mut bytes)
337            .await
338            .context("failed to read Snowflake SQL API partition response body")?;
339        bytes
340    };
341
342    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
343        let mut decoder = GzDecoder::new(&body_bytes[..]);
344        let mut decompressed = Vec::new();
345        decoder
346            .read_to_end(&mut decompressed)
347            .context("failed to decompress gzip response")?;
348        decompressed
349    } else {
350        body_bytes
351    };
352
353    if !status.is_success() && status.as_u16() != 202 {
354        let body_text = String::from_utf8_lossy(&body_bytes);
355        anyhow::bail!(
356            "snowflake sql api partition request http {}: {}",
357            status.as_u16(),
358            body_text
359        );
360    }
361
362    if body_bytes.is_empty() {
363        anyhow::bail!(
364            "snowflake sql api partition {} returned empty response body (http {})",
365            partition,
366            status.as_u16()
367        );
368    }
369
370    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
371        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
372        format!(
373            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
374            partition,
375            status.as_u16(),
376            body_preview
377        )
378    })
379}
380
381async fn run_sql(
382    http_client: Arc<dyn HttpClient>,
383    base_url: &str,
384    token: &str,
385    request: &serde_json::Value,
386) -> Result<SnowflakeStatementResponse> {
387    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
388
389    let request_body =
390        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
391
392    let http_request = Request::builder()
393        .method(Method::POST)
394        .uri(url.as_str())
395        .header("Authorization", format!("Bearer {token}"))
396        .header(
397            "X-Snowflake-Authorization-Token-Type",
398            "PROGRAMMATIC_ACCESS_TOKEN",
399        )
400        .header("Content-Type", "application/json")
401        .header("Accept", "application/json")
402        .header("User-Agent", "edit_prediction_cli")
403        .body(AsyncBody::from(request_body.clone()))?;
404
405    let response = http_client
406        .send(http_request)
407        .await
408        .context("failed to send request to Snowflake SQL API")?;
409
410    let status = response.status();
411    let body_bytes = {
412        use futures::AsyncReadExt as _;
413
414        let mut body = response.into_body();
415        let mut bytes = Vec::new();
416        body.read_to_end(&mut bytes)
417            .await
418            .context("failed to read Snowflake SQL API response body")?;
419        bytes
420    };
421
422    if !status.is_success() && status.as_u16() != 202 {
423        let body_text = String::from_utf8_lossy(&body_bytes);
424        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
425    }
426
427    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
428        .context("failed to parse Snowflake SQL API response JSON")
429}
430
431pub async fn fetch_rejected_examples_after(
432    http_client: Arc<dyn HttpClient>,
433    after_timestamps: &[String],
434    max_rows_per_timestamp: usize,
435    background_executor: BackgroundExecutor,
436) -> Result<Vec<Example>> {
437    if after_timestamps.is_empty() {
438        return Ok(Vec::new());
439    }
440
441    let progress = Progress::global();
442
443    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
444        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
445    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
446        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
447    )?;
448    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
449
450    let mut all_examples = Vec::new();
451
452    for after_date in after_timestamps.iter() {
453        let step_progress_name = format!("rejected>{after_date}");
454        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
455        step_progress.set_substatus("querying");
456
457        // Join rejected events with their corresponding request events to get the full context.
458        // We filter for V3 sampling data which contains the structured input we need.
459        // We also filter for predictions that were actually shown to the user (was_shown = true)
460        // to focus on explicit user rejections rather than implicit cancellations.
461        let statement = indoc! {r#"
462            SELECT
463                req.event_properties:request_id::string AS request_id,
464                req.device_id::string AS device_id,
465                req.time::string AS time,
466                req.event_properties:input AS input,
467                req.event_properties:prompt::string AS prompt,
468                req.event_properties:output::string AS output,
469                rej.event_properties:was_shown::boolean AS was_shown,
470                rej.event_properties:reason::string AS reason
471            FROM events req
472            INNER JOIN events rej
473                ON req.event_properties:request_id = rej.event_properties:request_id
474            WHERE req.event_type = ?
475                AND rej.event_type = ?
476                AND req.event_properties:version = 'V3'
477                AND rej.event_properties:was_shown = true
478                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
479            ORDER BY req.time ASC
480            LIMIT ?
481        "#};
482
483        let request = json!({
484            "statement": statement,
485            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
486            "database": "EVENTS",
487            "schema": "PUBLIC",
488            "warehouse": "DBT",
489            "role": role,
490            "bindings": {
491                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
492                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
493                "3": { "type": "TEXT", "value": after_date },
494                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
495            }
496        });
497
498        let response = run_sql_with_polling(
499            http_client.clone(),
500            &base_url,
501            &token,
502            &request,
503            &step_progress,
504            background_executor.clone(),
505        )
506        .await?;
507
508        let total_rows = response
509            .result_set_meta_data
510            .as_ref()
511            .and_then(|m| m.num_rows)
512            .unwrap_or(response.data.len() as i64);
513
514        let num_partitions = response
515            .result_set_meta_data
516            .as_ref()
517            .map(|m| m.partition_info.len())
518            .unwrap_or(1)
519            .max(1);
520
521        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
522        step_progress.set_substatus("parsing");
523
524        all_examples.extend(rejected_examples_from_response(&response)?);
525
526        if num_partitions > 1 {
527            let statement_handle = response
528                .statement_handle
529                .as_ref()
530                .context("response has multiple partitions but no statementHandle")?;
531
532            for partition in 1..num_partitions {
533                step_progress.set_substatus(format!(
534                    "fetching partition {}/{}",
535                    partition + 1,
536                    num_partitions
537                ));
538
539                let partition_response = fetch_partition(
540                    http_client.clone(),
541                    &base_url,
542                    &token,
543                    statement_handle,
544                    partition,
545                )
546                .await?;
547
548                all_examples.extend(rejected_examples_from_response(&partition_response)?);
549            }
550        }
551
552        step_progress.set_substatus("done");
553    }
554
555    Ok(all_examples)
556}
557
558fn rejected_examples_from_response(
559    response: &SnowflakeStatementResponse,
560) -> Result<impl Iterator<Item = Example> + '_> {
561    if let Some(code) = &response.code {
562        if code != SNOWFLAKE_SUCCESS_CODE {
563            anyhow::bail!(
564                "snowflake sql api returned error code={code} message={}",
565                response.message.as_deref().unwrap_or("<no message>")
566            );
567        }
568    }
569
570    let column_indices = get_column_indices(
571        &response.result_set_meta_data,
572        &[
573            "request_id",
574            "device_id",
575            "time",
576            "input",
577            "prompt",
578            "output",
579            "was_shown",
580            "reason",
581        ],
582    );
583
584    let iter = response
585        .data
586        .iter()
587        .enumerate()
588        .filter_map(move |(row_index, data_row)| {
589            let get_string = |name: &str| -> Option<String> {
590                let index = column_indices.get(name).copied()?;
591                match data_row.get(index)? {
592                    JsonValue::String(s) => Some(s.clone()),
593                    JsonValue::Null => None,
594                    other => Some(other.to_string()),
595                }
596            };
597
598            let get_json = |name: &str| -> Option<JsonValue> {
599                let index = column_indices.get(name).copied()?;
600                let value = data_row.get(index)?;
601                if value.is_null() {
602                    return None;
603                }
604                match value {
605                    JsonValue::String(s) => serde_json::from_str(s).ok(),
606                    other => Some(other.clone()),
607                }
608            };
609
610            let get_bool = |name: &str| -> Option<bool> {
611                let index = column_indices.get(name).copied()?;
612                match data_row.get(index)? {
613                    JsonValue::Bool(b) => Some(*b),
614                    JsonValue::String(s) => s.parse().ok(),
615                    _ => None,
616                }
617            };
618
619            let request_id_str = get_string("request_id");
620            let device_id = get_string("device_id");
621            let time = get_string("time");
622            let input_json = get_json("input");
623            let input: Option<ZetaPromptInput> =
624                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
625            let output = get_string("output");
626            let was_shown = get_bool("was_shown");
627            let reason = get_string("reason");
628
629            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
630                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
631                    Some(build_rejected_example(
632                        request_id,
633                        device_id,
634                        time,
635                        input,
636                        output,
637                        was_shown,
638                        reason,
639                    ))
640                }
641                _ => {
642                    log::warn!(
643                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
644                        request_id_str.is_some(),
645                        device_id.is_some(),
646                        time.is_some(),
647                        input_json.is_some(),
648                        output.is_some(),
649                        was_shown.is_some(),
650                        reason.is_some()
651                    );
652                    None
653                }
654            }
655        });
656
657    Ok(iter)
658}
659
660fn build_rejected_example(
661    request_id: String,
662    device_id: String,
663    time: String,
664    input: ZetaPromptInput,
665    output: String,
666    was_shown: bool,
667    reason: String,
668) -> Example {
669    let events: Vec<CapturedEvent> = input
670        .events
671        .iter()
672        .map(|event| match event.as_ref() {
673            zeta_prompt::Event::BufferChange {
674                path,
675                old_path,
676                diff,
677                predicted,
678                in_open_source_repo,
679            } => CapturedEvent {
680                path: path.clone(),
681                old_path: old_path.clone(),
682                diff: diff.clone(),
683                predicted: *predicted,
684                in_open_source_repo: *in_open_source_repo,
685            },
686        })
687        .collect();
688
689    let related_files: Vec<CapturedRelatedFile> = input
690        .related_files
691        .iter()
692        .map(|rf| CapturedRelatedFile {
693            path: rf.path.clone(),
694            max_row: rf.max_row,
695            excerpts: rf
696                .excerpts
697                .iter()
698                .map(|e| CapturedRelatedExcerpt {
699                    row_range: e.row_range.clone(),
700                    text: e.text.to_string(),
701                })
702                .collect(),
703        })
704        .collect();
705
706    let cursor_excerpt = input.cursor_excerpt.as_ref();
707    let cursor_offset = input.cursor_offset_in_excerpt;
708
709    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
710
711    let mut edit_history = String::new();
712    for event in &input.events {
713        zeta_prompt::write_event(&mut edit_history, event);
714        edit_history.push('\n');
715    }
716
717    let rejected_patch = build_rejected_patch(
718        &input.cursor_path,
719        cursor_excerpt,
720        &input.editable_range_in_excerpt,
721        &output,
722    );
723
724    let spec = ExampleSpec {
725        name: request_id.clone(),
726        repository_url: String::new(),
727        revision: String::new(),
728        tags: vec![format!("rejection:{}", reason.to_lowercase())],
729        reasoning: None,
730        uncommitted_diff: String::new(),
731        cursor_path: input.cursor_path.clone(),
732        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
733        edit_history,
734        expected_patches: Vec::new(),
735        rejected_patch: Some(rejected_patch),
736        captured_prompt_input: Some(CapturedPromptInput {
737            cursor_file_content: cursor_excerpt.to_string(),
738            cursor_offset,
739            cursor_row,
740            cursor_column,
741            events,
742            related_files,
743        }),
744        telemetry: Some(TelemetrySource {
745            request_id,
746            device_id,
747            time,
748            rejection_reason: reason,
749            was_shown,
750        }),
751    };
752
753    Example {
754        spec,
755        prompt_inputs: None,
756        prompt: None,
757        predictions: Vec::new(),
758        score: Vec::new(),
759        state: None,
760    }
761}
762
/// Converts a byte offset into `text` to a zero-based `(row, column)` pair.
///
/// `row` is the number of `'\n'` characters that start before `offset`; `column` is
/// the distance in bytes from the start of that row to `offset`. An `offset` past the
/// end of `text` is not clamped: the column keeps growing with it.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    let (row, line_start) = text
        .char_indices()
        .take_while(|&(index, _)| index < offset)
        .filter(|&(_, ch)| ch == '\n')
        // Each newline bumps the row and moves the line start just past itself.
        .fold((0u32, 0usize), |(row, _), (index, _)| (row + 1, index + 1));

    (row, (offset - line_start) as u32)
}
778
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. It is clamped to the length of `excerpt`, and if
/// it falls inside a multi-byte UTF-8 character it is moved back to the start of that
/// character, so the function never panics on an out-of-range or misaligned offset.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a boundary, so this terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
784
785fn build_rejected_patch(
786    cursor_path: &std::path::Path,
787    cursor_excerpt: &str,
788    editable_range: &std::ops::Range<usize>,
789    model_output: &str,
790) -> String {
791    let old_text = &cursor_excerpt[editable_range.clone()];
792
793    let editable_start_row = cursor_excerpt[..editable_range.start]
794        .chars()
795        .filter(|&c| c == '\n')
796        .count() as u32;
797
798    let diff_body = language::unified_diff_with_offsets(
799        old_text,
800        model_output,
801        editable_start_row,
802        editable_start_row,
803    );
804
805    let mut patch = String::new();
806    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
807    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
808    patch.push_str(&diff_body);
809    patch
810}
811
812fn get_column_indices(
813    meta: &Option<SnowflakeResultSetMetaData>,
814    names: &[&str],
815) -> std::collections::HashMap<String, usize> {
816    let mut indices = std::collections::HashMap::new();
817    if let Some(meta) = meta {
818        for (index, col) in meta.row_type.iter().enumerate() {
819            for &name in names {
820                if col.name.eq_ignore_ascii_case(name) {
821                    indices.insert(name.to_string(), index);
822                }
823            }
824        }
825    }
826    indices
827}