1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11
12use zeta_prompt::ZetaPromptInput;
13
14use crate::example::Example;
15use crate::progress::{InfoStyle, Progress, Step};
16use edit_prediction::example_spec::{
17 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
18 TelemetrySource,
19};
20use std::fmt::Write as _;
21
22const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
23const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
24const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
25const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
26const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
27
28const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
29const POLL_INTERVAL: Duration = Duration::from_secs(2);
30const MAX_POLL_ATTEMPTS: usize = 120;
31
/// Recognizes an input token shaped like `captured-after:{timestamp}`.
///
/// Returns the text following the `captured-after:` prefix (which may be empty),
/// or `None` when `input` does not start with that prefix.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    if input.starts_with(PREFIX) {
        Some(&input[PREFIX.len()..])
    } else {
        None
    }
}
36
/// Recognizes an input token shaped like `rejected-after:{timestamp}`.
///
/// Returns the text following the `rejected-after:` prefix (which may be empty),
/// or `None` when `input` does not start with that prefix.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    if input.starts_with(PREFIX) {
        Some(&input[PREFIX.len()..])
    } else {
        None
    }
}
41
42pub async fn fetch_captured_examples_after(
43 http_client: Arc<dyn HttpClient>,
44 after_timestamps: &[String],
45 max_rows_per_timestamp: usize,
46 background_executor: BackgroundExecutor,
47) -> Result<Vec<Example>> {
48 if after_timestamps.is_empty() {
49 return Ok(Vec::new());
50 }
51
52 let progress = Progress::global();
53
54 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
55 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
56 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
57 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
58 )?;
59 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
60
61 let mut all_examples = Vec::new();
62
63 for after_date in after_timestamps.iter() {
64 let step_progress_name = format!(">{after_date}");
65 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
66 step_progress.set_substatus("querying");
67
68 let statement = indoc! {r#"
69 SELECT
70 event_properties:example AS example
71 FROM events
72 WHERE event_type = ?
73 AND time > TRY_TO_TIMESTAMP_NTZ(?)
74 ORDER BY time ASC
75 LIMIT ?
76 "#};
77
78 let request = json!({
79 "statement": statement,
80 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
81 "database": "EVENTS",
82 "schema": "PUBLIC",
83 "warehouse": "DBT",
84 "role": role,
85 "bindings": {
86 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
87 "2": { "type": "TEXT", "value": after_date },
88 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
89 }
90 });
91
92 let response = run_sql_with_polling(
93 http_client.clone(),
94 &base_url,
95 &token,
96 &request,
97 &step_progress,
98 background_executor.clone(),
99 )
100 .await?;
101
102 let total_rows = response
103 .result_set_meta_data
104 .as_ref()
105 .and_then(|m| m.num_rows)
106 .unwrap_or(response.data.len() as i64);
107
108 let num_partitions = response
109 .result_set_meta_data
110 .as_ref()
111 .map(|m| m.partition_info.len())
112 .unwrap_or(1)
113 .max(1);
114
115 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
116 step_progress.set_substatus("parsing");
117
118 all_examples.extend(examples_from_response(&response)?);
119
120 if num_partitions > 1 {
121 let statement_handle = response
122 .statement_handle
123 .as_ref()
124 .context("response has multiple partitions but no statementHandle")?;
125
126 for partition in 1..num_partitions {
127 step_progress.set_substatus(format!(
128 "fetching partition {}/{}",
129 partition + 1,
130 num_partitions
131 ));
132
133 let partition_response = fetch_partition(
134 http_client.clone(),
135 &base_url,
136 &token,
137 statement_handle,
138 partition,
139 )
140 .await?;
141
142 all_examples.extend(examples_from_response(&partition_response)?);
143 }
144 }
145
146 step_progress.set_substatus("done");
147 }
148
149 Ok(all_examples)
150}
151
/// JSON body returned by the Snowflake SQL API, as deserialized by `run_sql`
/// and `fetch_partition`.
///
/// Keys are read in camelCase (e.g. `statementHandle`). Every field falls back
/// to its default when absent, so a body containing only some of the keys still
/// deserializes.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON values indexed by column position.
    #[serde(default)]
    data: Vec<Vec<JsonValue>>,
    /// Column, row-count and partition metadata, when the response includes it.
    #[serde(default)]
    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE` elsewhere in this file.
    #[serde(default)]
    code: Option<String>,
    /// Human-readable message, included in the error when `code` is not the success code.
    #[serde(default)]
    message: Option<String>,
    /// Handle used to build `/api/v2/statements/{handle}` URLs for polling and
    /// for fetching further partitions.
    #[serde(default)]
    statement_handle: Option<String>,
}
166
/// The `resultSetMetaData` object of a Snowflake SQL API response.
///
/// All fields default when missing from the JSON.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeResultSetMetaData {
    /// Column descriptors; a column's position in this list is used as its
    /// index into each data row.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count as reported by the server; only used for progress display.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
177
/// Placeholder for one entry of `partitionInfo`.
///
/// No fields are declared, so whatever keys the entry contains are ignored
/// during deserialization; callers only count how many entries there are.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
181
/// Description of a single result column (one entry of `rowType`).
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; defaults to an empty string when absent. Lookups in this
    /// file compare it ASCII-case-insensitively.
    #[serde(default)]
    name: String,
}
187
188fn examples_from_response(
189 response: &SnowflakeStatementResponse,
190) -> Result<impl Iterator<Item = Example> + '_> {
191 if let Some(code) = &response.code {
192 if code != SNOWFLAKE_SUCCESS_CODE {
193 anyhow::bail!(
194 "snowflake sql api returned error code={code} message={}",
195 response.message.as_deref().unwrap_or("<no message>")
196 );
197 }
198 }
199
200 let example_index = response
201 .result_set_meta_data
202 .as_ref()
203 .and_then(|m| {
204 m.row_type.iter().enumerate().find_map(|(index, col)| {
205 if col.name.eq_ignore_ascii_case("example") {
206 Some(index)
207 } else {
208 None
209 }
210 })
211 })
212 .unwrap_or(0);
213
214 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
215 let Some(example_value) = data_row.get(example_index) else {
216 return None;
217 };
218 if example_value.is_null() {
219 return None;
220 }
221
222 let parse_result = match example_value {
223 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
224 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
225 };
226
227 match parse_result {
228 Ok(spec) => Some(Example {
229 spec,
230 prompt_inputs: None,
231 prompt: None,
232 predictions: Vec::new(),
233 score: Vec::new(),
234 qa: Vec::new(),
235 state: None,
236 }),
237 Err(error) => {
238 let raw_json = serde_json::to_string_pretty(example_value)
239 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
240 log::error!(
241 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
242 );
243 None
244 }
245 }
246 });
247
248 Ok(iter)
249}
250
251async fn run_sql_with_polling(
252 http_client: Arc<dyn HttpClient>,
253 base_url: &str,
254 token: &str,
255 request: &serde_json::Value,
256 step_progress: &crate::progress::StepProgress,
257 background_executor: BackgroundExecutor,
258) -> Result<SnowflakeStatementResponse> {
259 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
260
261 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
262 let statement_handle = response
263 .statement_handle
264 .as_ref()
265 .context("async query response missing statementHandle")?
266 .clone();
267
268 for attempt in 1..=MAX_POLL_ATTEMPTS {
269 step_progress.set_substatus(format!("polling ({attempt})"));
270
271 background_executor.timer(POLL_INTERVAL).await;
272
273 response =
274 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
275
276 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
277 break;
278 }
279 }
280
281 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
282 anyhow::bail!(
283 "query still running after {} poll attempts ({} seconds)",
284 MAX_POLL_ATTEMPTS,
285 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
286 );
287 }
288 }
289
290 Ok(response)
291}
292
293async fn fetch_partition(
294 http_client: Arc<dyn HttpClient>,
295 base_url: &str,
296 token: &str,
297 statement_handle: &str,
298 partition: usize,
299) -> Result<SnowflakeStatementResponse> {
300 let url = format!(
301 "{}/api/v2/statements/{}?partition={}",
302 base_url.trim_end_matches('/'),
303 statement_handle,
304 partition
305 );
306
307 let http_request = Request::builder()
308 .method(Method::GET)
309 .uri(url.as_str())
310 .header("Authorization", format!("Bearer {token}"))
311 .header(
312 "X-Snowflake-Authorization-Token-Type",
313 "PROGRAMMATIC_ACCESS_TOKEN",
314 )
315 .header("Accept", "application/json")
316 .header("Accept-Encoding", "gzip")
317 .header("User-Agent", "edit_prediction_cli")
318 .body(AsyncBody::empty())?;
319
320 let response = http_client
321 .send(http_request)
322 .await
323 .context("failed to send partition request to Snowflake SQL API")?;
324
325 let status = response.status();
326 let content_encoding = response
327 .headers()
328 .get("content-encoding")
329 .and_then(|v| v.to_str().ok())
330 .map(|s| s.to_lowercase());
331
332 let body_bytes = {
333 use futures::AsyncReadExt as _;
334
335 let mut body = response.into_body();
336 let mut bytes = Vec::new();
337 body.read_to_end(&mut bytes)
338 .await
339 .context("failed to read Snowflake SQL API partition response body")?;
340 bytes
341 };
342
343 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
344 let mut decoder = GzDecoder::new(&body_bytes[..]);
345 let mut decompressed = Vec::new();
346 decoder
347 .read_to_end(&mut decompressed)
348 .context("failed to decompress gzip response")?;
349 decompressed
350 } else {
351 body_bytes
352 };
353
354 if !status.is_success() && status.as_u16() != 202 {
355 let body_text = String::from_utf8_lossy(&body_bytes);
356 anyhow::bail!(
357 "snowflake sql api partition request http {}: {}",
358 status.as_u16(),
359 body_text
360 );
361 }
362
363 if body_bytes.is_empty() {
364 anyhow::bail!(
365 "snowflake sql api partition {} returned empty response body (http {})",
366 partition,
367 status.as_u16()
368 );
369 }
370
371 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
372 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
373 format!(
374 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
375 partition,
376 status.as_u16(),
377 body_preview
378 )
379 })
380}
381
382async fn run_sql(
383 http_client: Arc<dyn HttpClient>,
384 base_url: &str,
385 token: &str,
386 request: &serde_json::Value,
387) -> Result<SnowflakeStatementResponse> {
388 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
389
390 let request_body =
391 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
392
393 let http_request = Request::builder()
394 .method(Method::POST)
395 .uri(url.as_str())
396 .header("Authorization", format!("Bearer {token}"))
397 .header(
398 "X-Snowflake-Authorization-Token-Type",
399 "PROGRAMMATIC_ACCESS_TOKEN",
400 )
401 .header("Content-Type", "application/json")
402 .header("Accept", "application/json")
403 .header("User-Agent", "edit_prediction_cli")
404 .body(AsyncBody::from(request_body.clone()))?;
405
406 let response = http_client
407 .send(http_request)
408 .await
409 .context("failed to send request to Snowflake SQL API")?;
410
411 let status = response.status();
412 let body_bytes = {
413 use futures::AsyncReadExt as _;
414
415 let mut body = response.into_body();
416 let mut bytes = Vec::new();
417 body.read_to_end(&mut bytes)
418 .await
419 .context("failed to read Snowflake SQL API response body")?;
420 bytes
421 };
422
423 if !status.is_success() && status.as_u16() != 202 {
424 let body_text = String::from_utf8_lossy(&body_bytes);
425 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
426 }
427
428 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
429 .context("failed to parse Snowflake SQL API response JSON")
430}
431
432pub async fn fetch_rejected_examples_after(
433 http_client: Arc<dyn HttpClient>,
434 after_timestamps: &[String],
435 max_rows_per_timestamp: usize,
436 background_executor: BackgroundExecutor,
437) -> Result<Vec<Example>> {
438 if after_timestamps.is_empty() {
439 return Ok(Vec::new());
440 }
441
442 let progress = Progress::global();
443
444 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
445 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
446 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
447 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
448 )?;
449 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
450
451 let mut all_examples = Vec::new();
452
453 for after_date in after_timestamps.iter() {
454 let step_progress_name = format!("rejected>{after_date}");
455 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
456 step_progress.set_substatus("querying");
457
458 // Join rejected events with their corresponding request events to get the full context.
459 // We filter for V3 sampling data which contains the structured input we need.
460 // We also filter for predictions that were actually shown to the user (was_shown = true)
461 // to focus on explicit user rejections rather than implicit cancellations.
462 let statement = indoc! {r#"
463 SELECT
464 req.event_properties:request_id::string AS request_id,
465 req.device_id::string AS device_id,
466 req.time::string AS time,
467 req.event_properties:input AS input,
468 req.event_properties:prompt::string AS prompt,
469 req.event_properties:output::string AS output,
470 rej.event_properties:was_shown::boolean AS was_shown,
471 rej.event_properties:reason::string AS reason
472 FROM events req
473 INNER JOIN events rej
474 ON req.event_properties:request_id = rej.event_properties:request_id
475 WHERE req.event_type = ?
476 AND rej.event_type = ?
477 AND req.event_properties:version = 'V3'
478 AND rej.event_properties:was_shown = true
479 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
480 ORDER BY req.time ASC
481 LIMIT ?
482 "#};
483
484 let request = json!({
485 "statement": statement,
486 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
487 "database": "EVENTS",
488 "schema": "PUBLIC",
489 "warehouse": "DBT",
490 "role": role,
491 "bindings": {
492 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
493 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
494 "3": { "type": "TEXT", "value": after_date },
495 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
496 }
497 });
498
499 let response = run_sql_with_polling(
500 http_client.clone(),
501 &base_url,
502 &token,
503 &request,
504 &step_progress,
505 background_executor.clone(),
506 )
507 .await?;
508
509 let total_rows = response
510 .result_set_meta_data
511 .as_ref()
512 .and_then(|m| m.num_rows)
513 .unwrap_or(response.data.len() as i64);
514
515 let num_partitions = response
516 .result_set_meta_data
517 .as_ref()
518 .map(|m| m.partition_info.len())
519 .unwrap_or(1)
520 .max(1);
521
522 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
523 step_progress.set_substatus("parsing");
524
525 all_examples.extend(rejected_examples_from_response(&response)?);
526
527 if num_partitions > 1 {
528 let statement_handle = response
529 .statement_handle
530 .as_ref()
531 .context("response has multiple partitions but no statementHandle")?;
532
533 for partition in 1..num_partitions {
534 step_progress.set_substatus(format!(
535 "fetching partition {}/{}",
536 partition + 1,
537 num_partitions
538 ));
539
540 let partition_response = fetch_partition(
541 http_client.clone(),
542 &base_url,
543 &token,
544 statement_handle,
545 partition,
546 )
547 .await?;
548
549 all_examples.extend(rejected_examples_from_response(&partition_response)?);
550 }
551 }
552
553 step_progress.set_substatus("done");
554 }
555
556 Ok(all_examples)
557}
558
559fn rejected_examples_from_response(
560 response: &SnowflakeStatementResponse,
561) -> Result<impl Iterator<Item = Example> + '_> {
562 if let Some(code) = &response.code {
563 if code != SNOWFLAKE_SUCCESS_CODE {
564 anyhow::bail!(
565 "snowflake sql api returned error code={code} message={}",
566 response.message.as_deref().unwrap_or("<no message>")
567 );
568 }
569 }
570
571 let column_indices = get_column_indices(
572 &response.result_set_meta_data,
573 &[
574 "request_id",
575 "device_id",
576 "time",
577 "input",
578 "prompt",
579 "output",
580 "was_shown",
581 "reason",
582 ],
583 );
584
585 let iter = response
586 .data
587 .iter()
588 .enumerate()
589 .filter_map(move |(row_index, data_row)| {
590 let get_string = |name: &str| -> Option<String> {
591 let index = column_indices.get(name).copied()?;
592 match data_row.get(index)? {
593 JsonValue::String(s) => Some(s.clone()),
594 JsonValue::Null => None,
595 other => Some(other.to_string()),
596 }
597 };
598
599 let get_json = |name: &str| -> Option<JsonValue> {
600 let index = column_indices.get(name).copied()?;
601 let value = data_row.get(index)?;
602 if value.is_null() {
603 return None;
604 }
605 match value {
606 JsonValue::String(s) => serde_json::from_str(s).ok(),
607 other => Some(other.clone()),
608 }
609 };
610
611 let get_bool = |name: &str| -> Option<bool> {
612 let index = column_indices.get(name).copied()?;
613 match data_row.get(index)? {
614 JsonValue::Bool(b) => Some(*b),
615 JsonValue::String(s) => s.parse().ok(),
616 _ => None,
617 }
618 };
619
620 let request_id_str = get_string("request_id");
621 let device_id = get_string("device_id");
622 let time = get_string("time");
623 let input_json = get_json("input");
624 let input: Option<ZetaPromptInput> =
625 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
626 let output = get_string("output");
627 let was_shown = get_bool("was_shown");
628 let reason = get_string("reason");
629
630 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
631 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
632 Some(build_rejected_example(
633 request_id,
634 device_id,
635 time,
636 input,
637 output,
638 was_shown,
639 reason,
640 ))
641 }
642 _ => {
643 log::warn!(
644 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
645 request_id_str.is_some(),
646 device_id.is_some(),
647 time.is_some(),
648 input_json.is_some(),
649 output.is_some(),
650 was_shown.is_some(),
651 reason.is_some()
652 );
653 None
654 }
655 }
656 });
657
658 Ok(iter)
659}
660
661fn build_rejected_example(
662 request_id: String,
663 device_id: String,
664 time: String,
665 input: ZetaPromptInput,
666 output: String,
667 was_shown: bool,
668 reason: String,
669) -> Example {
670 let events: Vec<CapturedEvent> = input
671 .events
672 .iter()
673 .map(|event| match event.as_ref() {
674 zeta_prompt::Event::BufferChange {
675 path,
676 old_path,
677 diff,
678 predicted,
679 in_open_source_repo,
680 } => CapturedEvent {
681 path: path.clone(),
682 old_path: old_path.clone(),
683 diff: diff.clone(),
684 predicted: *predicted,
685 in_open_source_repo: *in_open_source_repo,
686 },
687 })
688 .collect();
689
690 let related_files: Vec<CapturedRelatedFile> = input
691 .related_files
692 .iter()
693 .map(|rf| CapturedRelatedFile {
694 path: rf.path.clone(),
695 max_row: rf.max_row,
696 excerpts: rf
697 .excerpts
698 .iter()
699 .map(|e| CapturedRelatedExcerpt {
700 row_range: e.row_range.clone(),
701 text: e.text.to_string(),
702 })
703 .collect(),
704 })
705 .collect();
706
707 let cursor_excerpt = input.cursor_excerpt.as_ref();
708 let cursor_offset = input.cursor_offset_in_excerpt;
709
710 let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
711
712 let mut edit_history = String::new();
713 for event in &input.events {
714 zeta_prompt::write_event(&mut edit_history, event);
715 edit_history.push('\n');
716 }
717
718 let rejected_patch = build_rejected_patch(
719 &input.cursor_path,
720 cursor_excerpt,
721 &input.editable_range_in_excerpt,
722 &output,
723 );
724
725 let spec = ExampleSpec {
726 name: request_id.clone(),
727 repository_url: String::new(),
728 revision: String::new(),
729 tags: vec![format!("rejection:{}", reason.to_lowercase())],
730 reasoning: None,
731 uncommitted_diff: String::new(),
732 cursor_path: input.cursor_path.clone(),
733 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
734 edit_history,
735 expected_patches: Vec::new(),
736 rejected_patch: Some(rejected_patch),
737 captured_prompt_input: Some(CapturedPromptInput {
738 cursor_file_content: cursor_excerpt.to_string(),
739 cursor_offset,
740 cursor_row,
741 cursor_column,
742 events,
743 related_files,
744 }),
745 telemetry: Some(TelemetrySource {
746 request_id,
747 device_id,
748 time,
749 rejection_reason: reason,
750 was_shown,
751 }),
752 };
753
754 Example {
755 spec,
756 prompt_inputs: None,
757 prompt: None,
758 predictions: Vec::new(),
759 score: Vec::new(),
760 qa: Vec::new(),
761 state: None,
762 }
763}
764
/// Converts a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// The row is the number of `'\n'` bytes located strictly before `offset`; the
/// column is the distance in bytes (not characters) from the start of that
/// line to `offset`. An `offset` past the end of `text` is not clamped: the
/// column keeps growing with it.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // Only the part of the text in front of the offset matters. Scanning bytes
    // is safe because the newline byte never occurs inside a multi-byte character.
    let scanned = &text.as_bytes()[..offset.min(text.len())];

    let row = scanned.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = scanned
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline_index| newline_index + 1);

    (row, (offset - line_start) as u32)
}
780
/// Returns `excerpt` with the marker `[CURSOR_POSITION]` inserted at byte
/// offset `cursor_offset`.
///
/// The offset is clamped to the length of `excerpt`. Because it originates
/// from recorded data, it is not guaranteed to fall on a character boundary;
/// in that case the marker is placed before the character that contains the
/// offset instead of panicking on the string slice.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Index 0 is always a boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
786
787fn build_rejected_patch(
788 cursor_path: &std::path::Path,
789 cursor_excerpt: &str,
790 editable_range: &std::ops::Range<usize>,
791 model_output: &str,
792) -> String {
793 let old_text = &cursor_excerpt[editable_range.clone()];
794
795 let editable_start_row = cursor_excerpt[..editable_range.start]
796 .chars()
797 .filter(|&c| c == '\n')
798 .count() as u32;
799
800 let diff_body = language::unified_diff_with_offsets(
801 old_text,
802 model_output,
803 editable_start_row,
804 editable_start_row,
805 );
806
807 let mut patch = String::new();
808 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
809 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
810 patch.push_str(&diff_body);
811 patch
812}
813
814fn get_column_indices(
815 meta: &Option<SnowflakeResultSetMetaData>,
816 names: &[&str],
817) -> std::collections::HashMap<String, usize> {
818 let mut indices = std::collections::HashMap::new();
819 if let Some(meta) = meta {
820 for (index, col) in meta.row_type.iter().enumerate() {
821 for &name in names {
822 if col.name.eq_ignore_ascii_case(name) {
823 indices.insert(name.to_string(), index);
824 }
825 }
826 }
827 }
828 indices
829}