pull_examples.rs

  1use anyhow::{Context as _, Result};
  2use flate2::read::GzDecoder;
  3use gpui::BackgroundExecutor;
  4use http_client::{AsyncBody, HttpClient, Method, Request};
  5use indoc::indoc;
  6use serde::Deserialize;
  7use serde_json::{Value as JsonValue, json};
  8use std::io::Read;
  9use std::sync::Arc;
 10use std::time::Duration;
 11
 12use zeta_prompt::ZetaPromptInput;
 13
 14use crate::example::Example;
 15use crate::progress::{InfoStyle, Progress, Step};
 16use edit_prediction::example_spec::{
 17    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
 18    TelemetrySource,
 19};
 20use std::fmt::Write as _;
 21
 22const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
 23const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
 24const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
 25const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
 26const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
 27
 28const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
 29const POLL_INTERVAL: Duration = Duration::from_secs(2);
 30const MAX_POLL_ATTEMPTS: usize = 120;
 31
 32/// Parse an input token of the form `captured-after:{timestamp}`.
 33pub fn parse_captured_after_input(input: &str) -> Option<&str> {
 34    input.strip_prefix("captured-after:")
 35}
 36
 37/// Parse an input token of the form `rejected-after:{timestamp}`.
 38pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
 39    input.strip_prefix("rejected-after:")
 40}
 41
 42pub async fn fetch_captured_examples_after(
 43    http_client: Arc<dyn HttpClient>,
 44    after_timestamps: &[String],
 45    max_rows_per_timestamp: usize,
 46    background_executor: BackgroundExecutor,
 47) -> Result<Vec<Example>> {
 48    if after_timestamps.is_empty() {
 49        return Ok(Vec::new());
 50    }
 51
 52    let progress = Progress::global();
 53
 54    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 55        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 56    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 57        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 58    )?;
 59    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 60
 61    let mut all_examples = Vec::new();
 62
 63    for after_date in after_timestamps.iter() {
 64        let step_progress_name = format!(">{after_date}");
 65        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 66        step_progress.set_substatus("querying");
 67
 68        let statement = indoc! {r#"
 69            SELECT
 70                event_properties:example AS example
 71            FROM events
 72            WHERE event_type = ?
 73                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 74            ORDER BY time ASC
 75            LIMIT ?
 76        "#};
 77
 78        let request = json!({
 79            "statement": statement,
 80            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 81            "database": "EVENTS",
 82            "schema": "PUBLIC",
 83            "warehouse": "DBT",
 84            "role": role,
 85            "bindings": {
 86                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 87                "2": { "type": "TEXT", "value": after_date },
 88                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 89            }
 90        });
 91
 92        let response = run_sql_with_polling(
 93            http_client.clone(),
 94            &base_url,
 95            &token,
 96            &request,
 97            &step_progress,
 98            background_executor.clone(),
 99        )
100        .await?;
101
102        let total_rows = response
103            .result_set_meta_data
104            .as_ref()
105            .and_then(|m| m.num_rows)
106            .unwrap_or(response.data.len() as i64);
107
108        let num_partitions = response
109            .result_set_meta_data
110            .as_ref()
111            .map(|m| m.partition_info.len())
112            .unwrap_or(1)
113            .max(1);
114
115        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
116        step_progress.set_substatus("parsing");
117
118        all_examples.extend(examples_from_response(&response)?);
119
120        if num_partitions > 1 {
121            let statement_handle = response
122                .statement_handle
123                .as_ref()
124                .context("response has multiple partitions but no statementHandle")?;
125
126            for partition in 1..num_partitions {
127                step_progress.set_substatus(format!(
128                    "fetching partition {}/{}",
129                    partition + 1,
130                    num_partitions
131                ));
132
133                let partition_response = fetch_partition(
134                    http_client.clone(),
135                    &base_url,
136                    &token,
137                    statement_handle,
138                    partition,
139                )
140                .await?;
141
142                all_examples.extend(examples_from_response(&partition_response)?);
143            }
144        }
145
146        step_progress.set_substatus("done");
147    }
148
149    Ok(all_examples)
150}
151
152#[derive(Debug, Clone, Deserialize)]
153#[serde(rename_all = "camelCase")]
154struct SnowflakeStatementResponse {
155    #[serde(default)]
156    data: Vec<Vec<JsonValue>>,
157    #[serde(default)]
158    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
159    #[serde(default)]
160    code: Option<String>,
161    #[serde(default)]
162    message: Option<String>,
163    #[serde(default)]
164    statement_handle: Option<String>,
165}
166
167#[derive(Debug, Clone, Deserialize)]
168#[serde(rename_all = "camelCase")]
169struct SnowflakeResultSetMetaData {
170    #[serde(default, rename = "rowType")]
171    row_type: Vec<SnowflakeColumnMeta>,
172    #[serde(default)]
173    num_rows: Option<i64>,
174    #[serde(default)]
175    partition_info: Vec<SnowflakePartitionInfo>,
176}
177
178#[derive(Debug, Clone, Deserialize)]
179#[serde(rename_all = "camelCase")]
180struct SnowflakePartitionInfo {}
181
182#[derive(Debug, Clone, Deserialize)]
183struct SnowflakeColumnMeta {
184    #[serde(default)]
185    name: String,
186}
187
188fn examples_from_response(
189    response: &SnowflakeStatementResponse,
190) -> Result<impl Iterator<Item = Example> + '_> {
191    if let Some(code) = &response.code {
192        if code != SNOWFLAKE_SUCCESS_CODE {
193            anyhow::bail!(
194                "snowflake sql api returned error code={code} message={}",
195                response.message.as_deref().unwrap_or("<no message>")
196            );
197        }
198    }
199
200    let example_index = response
201        .result_set_meta_data
202        .as_ref()
203        .and_then(|m| {
204            m.row_type.iter().enumerate().find_map(|(index, col)| {
205                if col.name.eq_ignore_ascii_case("example") {
206                    Some(index)
207                } else {
208                    None
209                }
210            })
211        })
212        .unwrap_or(0);
213
214    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
215        let Some(example_value) = data_row.get(example_index) else {
216            return None;
217        };
218        if example_value.is_null() {
219            return None;
220        }
221
222        let parse_result = match example_value {
223            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
224            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
225        };
226
227        match parse_result {
228            Ok(spec) => Some(Example {
229                spec,
230                prompt_inputs: None,
231                prompt: None,
232                predictions: Vec::new(),
233                score: Vec::new(),
234                qa: Vec::new(),
235                state: None,
236            }),
237            Err(error) => {
238                let raw_json = serde_json::to_string_pretty(example_value)
239                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
240                log::error!(
241                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
242                );
243                None
244            }
245        }
246    });
247
248    Ok(iter)
249}
250
251async fn run_sql_with_polling(
252    http_client: Arc<dyn HttpClient>,
253    base_url: &str,
254    token: &str,
255    request: &serde_json::Value,
256    step_progress: &crate::progress::StepProgress,
257    background_executor: BackgroundExecutor,
258) -> Result<SnowflakeStatementResponse> {
259    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
260
261    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
262        let statement_handle = response
263            .statement_handle
264            .as_ref()
265            .context("async query response missing statementHandle")?
266            .clone();
267
268        for attempt in 1..=MAX_POLL_ATTEMPTS {
269            step_progress.set_substatus(format!("polling ({attempt})"));
270
271            background_executor.timer(POLL_INTERVAL).await;
272
273            response =
274                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
275
276            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
277                break;
278            }
279        }
280
281        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
282            anyhow::bail!(
283                "query still running after {} poll attempts ({} seconds)",
284                MAX_POLL_ATTEMPTS,
285                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
286            );
287        }
288    }
289
290    Ok(response)
291}
292
293async fn fetch_partition(
294    http_client: Arc<dyn HttpClient>,
295    base_url: &str,
296    token: &str,
297    statement_handle: &str,
298    partition: usize,
299) -> Result<SnowflakeStatementResponse> {
300    let url = format!(
301        "{}/api/v2/statements/{}?partition={}",
302        base_url.trim_end_matches('/'),
303        statement_handle,
304        partition
305    );
306
307    let http_request = Request::builder()
308        .method(Method::GET)
309        .uri(url.as_str())
310        .header("Authorization", format!("Bearer {token}"))
311        .header(
312            "X-Snowflake-Authorization-Token-Type",
313            "PROGRAMMATIC_ACCESS_TOKEN",
314        )
315        .header("Accept", "application/json")
316        .header("Accept-Encoding", "gzip")
317        .header("User-Agent", "edit_prediction_cli")
318        .body(AsyncBody::empty())?;
319
320    let response = http_client
321        .send(http_request)
322        .await
323        .context("failed to send partition request to Snowflake SQL API")?;
324
325    let status = response.status();
326    let content_encoding = response
327        .headers()
328        .get("content-encoding")
329        .and_then(|v| v.to_str().ok())
330        .map(|s| s.to_lowercase());
331
332    let body_bytes = {
333        use futures::AsyncReadExt as _;
334
335        let mut body = response.into_body();
336        let mut bytes = Vec::new();
337        body.read_to_end(&mut bytes)
338            .await
339            .context("failed to read Snowflake SQL API partition response body")?;
340        bytes
341    };
342
343    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
344        let mut decoder = GzDecoder::new(&body_bytes[..]);
345        let mut decompressed = Vec::new();
346        decoder
347            .read_to_end(&mut decompressed)
348            .context("failed to decompress gzip response")?;
349        decompressed
350    } else {
351        body_bytes
352    };
353
354    if !status.is_success() && status.as_u16() != 202 {
355        let body_text = String::from_utf8_lossy(&body_bytes);
356        anyhow::bail!(
357            "snowflake sql api partition request http {}: {}",
358            status.as_u16(),
359            body_text
360        );
361    }
362
363    if body_bytes.is_empty() {
364        anyhow::bail!(
365            "snowflake sql api partition {} returned empty response body (http {})",
366            partition,
367            status.as_u16()
368        );
369    }
370
371    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
372        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
373        format!(
374            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
375            partition,
376            status.as_u16(),
377            body_preview
378        )
379    })
380}
381
382async fn run_sql(
383    http_client: Arc<dyn HttpClient>,
384    base_url: &str,
385    token: &str,
386    request: &serde_json::Value,
387) -> Result<SnowflakeStatementResponse> {
388    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
389
390    let request_body =
391        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
392
393    let http_request = Request::builder()
394        .method(Method::POST)
395        .uri(url.as_str())
396        .header("Authorization", format!("Bearer {token}"))
397        .header(
398            "X-Snowflake-Authorization-Token-Type",
399            "PROGRAMMATIC_ACCESS_TOKEN",
400        )
401        .header("Content-Type", "application/json")
402        .header("Accept", "application/json")
403        .header("User-Agent", "edit_prediction_cli")
404        .body(AsyncBody::from(request_body.clone()))?;
405
406    let response = http_client
407        .send(http_request)
408        .await
409        .context("failed to send request to Snowflake SQL API")?;
410
411    let status = response.status();
412    let body_bytes = {
413        use futures::AsyncReadExt as _;
414
415        let mut body = response.into_body();
416        let mut bytes = Vec::new();
417        body.read_to_end(&mut bytes)
418            .await
419            .context("failed to read Snowflake SQL API response body")?;
420        bytes
421    };
422
423    if !status.is_success() && status.as_u16() != 202 {
424        let body_text = String::from_utf8_lossy(&body_bytes);
425        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
426    }
427
428    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
429        .context("failed to parse Snowflake SQL API response JSON")
430}
431
432pub async fn fetch_rejected_examples_after(
433    http_client: Arc<dyn HttpClient>,
434    after_timestamps: &[String],
435    max_rows_per_timestamp: usize,
436    background_executor: BackgroundExecutor,
437) -> Result<Vec<Example>> {
438    if after_timestamps.is_empty() {
439        return Ok(Vec::new());
440    }
441
442    let progress = Progress::global();
443
444    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
445        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
446    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
447        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
448    )?;
449    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
450
451    let mut all_examples = Vec::new();
452
453    for after_date in after_timestamps.iter() {
454        let step_progress_name = format!("rejected>{after_date}");
455        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
456        step_progress.set_substatus("querying");
457
458        // Join rejected events with their corresponding request events to get the full context.
459        // We filter for V3 sampling data which contains the structured input we need.
460        // We also filter for predictions that were actually shown to the user (was_shown = true)
461        // to focus on explicit user rejections rather than implicit cancellations.
462        let statement = indoc! {r#"
463            SELECT
464                req.event_properties:request_id::string AS request_id,
465                req.device_id::string AS device_id,
466                req.time::string AS time,
467                req.event_properties:input AS input,
468                req.event_properties:prompt::string AS prompt,
469                req.event_properties:output::string AS output,
470                rej.event_properties:was_shown::boolean AS was_shown,
471                rej.event_properties:reason::string AS reason
472            FROM events req
473            INNER JOIN events rej
474                ON req.event_properties:request_id = rej.event_properties:request_id
475            WHERE req.event_type = ?
476                AND rej.event_type = ?
477                AND req.event_properties:version = 'V3'
478                AND rej.event_properties:was_shown = true
479                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
480            ORDER BY req.time ASC
481            LIMIT ?
482        "#};
483
484        let request = json!({
485            "statement": statement,
486            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
487            "database": "EVENTS",
488            "schema": "PUBLIC",
489            "warehouse": "DBT",
490            "role": role,
491            "bindings": {
492                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
493                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
494                "3": { "type": "TEXT", "value": after_date },
495                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
496            }
497        });
498
499        let response = run_sql_with_polling(
500            http_client.clone(),
501            &base_url,
502            &token,
503            &request,
504            &step_progress,
505            background_executor.clone(),
506        )
507        .await?;
508
509        let total_rows = response
510            .result_set_meta_data
511            .as_ref()
512            .and_then(|m| m.num_rows)
513            .unwrap_or(response.data.len() as i64);
514
515        let num_partitions = response
516            .result_set_meta_data
517            .as_ref()
518            .map(|m| m.partition_info.len())
519            .unwrap_or(1)
520            .max(1);
521
522        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
523        step_progress.set_substatus("parsing");
524
525        all_examples.extend(rejected_examples_from_response(&response)?);
526
527        if num_partitions > 1 {
528            let statement_handle = response
529                .statement_handle
530                .as_ref()
531                .context("response has multiple partitions but no statementHandle")?;
532
533            for partition in 1..num_partitions {
534                step_progress.set_substatus(format!(
535                    "fetching partition {}/{}",
536                    partition + 1,
537                    num_partitions
538                ));
539
540                let partition_response = fetch_partition(
541                    http_client.clone(),
542                    &base_url,
543                    &token,
544                    statement_handle,
545                    partition,
546                )
547                .await?;
548
549                all_examples.extend(rejected_examples_from_response(&partition_response)?);
550            }
551        }
552
553        step_progress.set_substatus("done");
554    }
555
556    Ok(all_examples)
557}
558
559fn rejected_examples_from_response(
560    response: &SnowflakeStatementResponse,
561) -> Result<impl Iterator<Item = Example> + '_> {
562    if let Some(code) = &response.code {
563        if code != SNOWFLAKE_SUCCESS_CODE {
564            anyhow::bail!(
565                "snowflake sql api returned error code={code} message={}",
566                response.message.as_deref().unwrap_or("<no message>")
567            );
568        }
569    }
570
571    let column_indices = get_column_indices(
572        &response.result_set_meta_data,
573        &[
574            "request_id",
575            "device_id",
576            "time",
577            "input",
578            "prompt",
579            "output",
580            "was_shown",
581            "reason",
582        ],
583    );
584
585    let iter = response
586        .data
587        .iter()
588        .enumerate()
589        .filter_map(move |(row_index, data_row)| {
590            let get_string = |name: &str| -> Option<String> {
591                let index = column_indices.get(name).copied()?;
592                match data_row.get(index)? {
593                    JsonValue::String(s) => Some(s.clone()),
594                    JsonValue::Null => None,
595                    other => Some(other.to_string()),
596                }
597            };
598
599            let get_json = |name: &str| -> Option<JsonValue> {
600                let index = column_indices.get(name).copied()?;
601                let value = data_row.get(index)?;
602                if value.is_null() {
603                    return None;
604                }
605                match value {
606                    JsonValue::String(s) => serde_json::from_str(s).ok(),
607                    other => Some(other.clone()),
608                }
609            };
610
611            let get_bool = |name: &str| -> Option<bool> {
612                let index = column_indices.get(name).copied()?;
613                match data_row.get(index)? {
614                    JsonValue::Bool(b) => Some(*b),
615                    JsonValue::String(s) => s.parse().ok(),
616                    _ => None,
617                }
618            };
619
620            let request_id_str = get_string("request_id");
621            let device_id = get_string("device_id");
622            let time = get_string("time");
623            let input_json = get_json("input");
624            let input: Option<ZetaPromptInput> =
625                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
626            let output = get_string("output");
627            let was_shown = get_bool("was_shown");
628            let reason = get_string("reason");
629
630            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
631                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
632                    Some(build_rejected_example(
633                        request_id,
634                        device_id,
635                        time,
636                        input,
637                        output,
638                        was_shown,
639                        reason,
640                    ))
641                }
642                _ => {
643                    log::warn!(
644                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
645                        request_id_str.is_some(),
646                        device_id.is_some(),
647                        time.is_some(),
648                        input_json.is_some(),
649                        output.is_some(),
650                        was_shown.is_some(),
651                        reason.is_some()
652                    );
653                    None
654                }
655            }
656        });
657
658    Ok(iter)
659}
660
661fn build_rejected_example(
662    request_id: String,
663    device_id: String,
664    time: String,
665    input: ZetaPromptInput,
666    output: String,
667    was_shown: bool,
668    reason: String,
669) -> Example {
670    let events: Vec<CapturedEvent> = input
671        .events
672        .iter()
673        .map(|event| match event.as_ref() {
674            zeta_prompt::Event::BufferChange {
675                path,
676                old_path,
677                diff,
678                predicted,
679                in_open_source_repo,
680            } => CapturedEvent {
681                path: path.clone(),
682                old_path: old_path.clone(),
683                diff: diff.clone(),
684                predicted: *predicted,
685                in_open_source_repo: *in_open_source_repo,
686            },
687        })
688        .collect();
689
690    let related_files: Vec<CapturedRelatedFile> = input
691        .related_files
692        .iter()
693        .map(|rf| CapturedRelatedFile {
694            path: rf.path.clone(),
695            max_row: rf.max_row,
696            excerpts: rf
697                .excerpts
698                .iter()
699                .map(|e| CapturedRelatedExcerpt {
700                    row_range: e.row_range.clone(),
701                    text: e.text.to_string(),
702                })
703                .collect(),
704        })
705        .collect();
706
707    let cursor_excerpt = input.cursor_excerpt.as_ref();
708    let cursor_offset = input.cursor_offset_in_excerpt;
709
710    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
711
712    let mut edit_history = String::new();
713    for event in &input.events {
714        zeta_prompt::write_event(&mut edit_history, event);
715        edit_history.push('\n');
716    }
717
718    let rejected_patch = build_rejected_patch(
719        &input.cursor_path,
720        cursor_excerpt,
721        &input.editable_range_in_excerpt,
722        &output,
723    );
724
725    let spec = ExampleSpec {
726        name: request_id.clone(),
727        repository_url: String::new(),
728        revision: String::new(),
729        tags: vec![format!("rejection:{}", reason.to_lowercase())],
730        reasoning: None,
731        uncommitted_diff: String::new(),
732        cursor_path: input.cursor_path.clone(),
733        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
734        edit_history,
735        expected_patches: Vec::new(),
736        rejected_patch: Some(rejected_patch),
737        captured_prompt_input: Some(CapturedPromptInput {
738            cursor_file_content: cursor_excerpt.to_string(),
739            cursor_offset,
740            cursor_row,
741            cursor_column,
742            events,
743            related_files,
744        }),
745        telemetry: Some(TelemetrySource {
746            request_id,
747            device_id,
748            time,
749            rejection_reason: reason,
750            was_shown,
751        }),
752    };
753
754    Example {
755        spec,
756        prompt_inputs: None,
757        prompt: None,
758        predictions: Vec::new(),
759        score: Vec::new(),
760        qa: Vec::new(),
761        state: None,
762    }
763}
764
/// Converts a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// `row` is the number of `\n` characters that start before `offset`; `column` is the
/// distance in bytes (not characters) from the start of that row to `offset`. An
/// `offset` beyond the end of `text` is not clamped: the row is that of the last line
/// and the column keeps counting past the end.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // A newline is a single byte that never occurs inside a multi-byte character, so
    // scanning the bytes in front of the offset finds exactly the same newlines as
    // walking the characters would.
    let scanned = &text.as_bytes()[..offset.min(text.len())];

    let row = scanned.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = scanned
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline_at| newline_at + 1);

    (row, (offset - line_start) as u32)
}
780
/// Returns `excerpt` with the marker `[CURSOR_POSITION]` inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. It is clamped to the length of `excerpt`, and an
/// offset that falls inside a multi-byte character is moved back to the start of that
/// character, so the function never panics on an offset that does not line up with
/// the text.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Slicing a `str` off a character boundary panics. Index 0 is always a boundary,
    // so this walk terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
786
787fn build_rejected_patch(
788    cursor_path: &std::path::Path,
789    cursor_excerpt: &str,
790    editable_range: &std::ops::Range<usize>,
791    model_output: &str,
792) -> String {
793    let old_text = &cursor_excerpt[editable_range.clone()];
794
795    let editable_start_row = cursor_excerpt[..editable_range.start]
796        .chars()
797        .filter(|&c| c == '\n')
798        .count() as u32;
799
800    let diff_body = language::unified_diff_with_offsets(
801        old_text,
802        model_output,
803        editable_start_row,
804        editable_start_row,
805    );
806
807    let mut patch = String::new();
808    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
809    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
810    patch.push_str(&diff_body);
811    patch
812}
813
814fn get_column_indices(
815    meta: &Option<SnowflakeResultSetMetaData>,
816    names: &[&str],
817) -> std::collections::HashMap<String, usize> {
818    let mut indices = std::collections::HashMap::new();
819    if let Some(meta) = meta {
820        for (index, col) in meta.row_type.iter().enumerate() {
821            for &name in names {
822                if col.name.eq_ignore_ascii_case(name) {
823                    indices.insert(name.to_string(), index);
824                }
825            }
826        }
827    }
828    indices
829}