1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11
12use zeta_prompt::ZetaPromptInput;
13
14use crate::example::Example;
15use crate::progress::{InfoStyle, Progress, Step};
16use edit_prediction::example_spec::{
17 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
18 TelemetrySource,
19};
20use std::fmt::Write as _;
21
/// Value of the response `code` field that this module treats as success.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Value of the response `code` field that makes this module poll for completion.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// Event type selected when pulling captured examples.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// Event type of the request side of the rejected-prediction join.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// Event type of the rejection side of the rejected-prediction join.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";

/// Value sent as the `timeout` field of every statement request (two minutes).
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 2 * 60;
/// Pause between two polls of a statement that is still running.
const POLL_INTERVAL: Duration = Duration::from_millis(2_000);
/// Number of polls after which a still-running statement is reported as an error.
const MAX_POLL_ATTEMPTS: usize = 120;
31
/// Recognizes an input token of the form `captured-after:{timestamp}`.
///
/// Returns everything following the `captured-after:` prefix, or `None` when the
/// token does not start with that prefix. The remainder is returned as-is and is
/// not validated here.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
36
/// Recognizes an input token of the form `rejected-after:{timestamp}`.
///
/// Returns everything following the `rejected-after:` prefix, or `None` when the
/// token does not start with that prefix. The remainder is returned as-is and is
/// not validated here.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
41
42pub async fn fetch_captured_examples_after(
43 http_client: Arc<dyn HttpClient>,
44 after_timestamps: &[String],
45 max_rows_per_timestamp: usize,
46 background_executor: BackgroundExecutor,
47) -> Result<Vec<Example>> {
48 if after_timestamps.is_empty() {
49 return Ok(Vec::new());
50 }
51
52 let progress = Progress::global();
53
54 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
55 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
56 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
57 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
58 )?;
59 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
60
61 let mut all_examples = Vec::new();
62
63 for after_date in after_timestamps.iter() {
64 let step_progress_name = format!(">{after_date}");
65 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
66 step_progress.set_substatus("querying");
67
68 let statement = indoc! {r#"
69 SELECT
70 event_properties:example AS example
71 FROM events
72 WHERE event_type = ?
73 AND time > TRY_TO_TIMESTAMP_NTZ(?)
74 ORDER BY time ASC
75 LIMIT ?
76 "#};
77
78 let request = json!({
79 "statement": statement,
80 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
81 "database": "EVENTS",
82 "schema": "PUBLIC",
83 "warehouse": "DBT",
84 "role": role,
85 "bindings": {
86 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
87 "2": { "type": "TEXT", "value": after_date },
88 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
89 }
90 });
91
92 let response = run_sql_with_polling(
93 http_client.clone(),
94 &base_url,
95 &token,
96 &request,
97 &step_progress,
98 background_executor.clone(),
99 )
100 .await?;
101
102 let total_rows = response
103 .result_set_meta_data
104 .as_ref()
105 .and_then(|m| m.num_rows)
106 .unwrap_or(response.data.len() as i64);
107
108 let num_partitions = response
109 .result_set_meta_data
110 .as_ref()
111 .map(|m| m.partition_info.len())
112 .unwrap_or(1)
113 .max(1);
114
115 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
116 step_progress.set_substatus("parsing");
117
118 all_examples.extend(examples_from_response(&response)?);
119
120 if num_partitions > 1 {
121 let statement_handle = response
122 .statement_handle
123 .as_ref()
124 .context("response has multiple partitions but no statementHandle")?;
125
126 for partition in 1..num_partitions {
127 step_progress.set_substatus(format!(
128 "fetching partition {}/{}",
129 partition + 1,
130 num_partitions
131 ));
132
133 let partition_response = fetch_partition(
134 http_client.clone(),
135 &base_url,
136 &token,
137 statement_handle,
138 partition,
139 )
140 .await?;
141
142 all_examples.extend(examples_from_response(&partition_response)?);
143 }
144 }
145
146 step_progress.set_substatus("done");
147 }
148
149 Ok(all_examples)
150}
151
152#[derive(Debug, Clone, Deserialize)]
153#[serde(rename_all = "camelCase")]
154struct SnowflakeStatementResponse {
155 #[serde(default)]
156 data: Vec<Vec<JsonValue>>,
157 #[serde(default)]
158 result_set_meta_data: Option<SnowflakeResultSetMetaData>,
159 #[serde(default)]
160 code: Option<String>,
161 #[serde(default)]
162 message: Option<String>,
163 #[serde(default)]
164 statement_handle: Option<String>,
165}
166
167#[derive(Debug, Clone, Deserialize)]
168#[serde(rename_all = "camelCase")]
169struct SnowflakeResultSetMetaData {
170 #[serde(default, rename = "rowType")]
171 row_type: Vec<SnowflakeColumnMeta>,
172 #[serde(default)]
173 num_rows: Option<i64>,
174 #[serde(default)]
175 partition_info: Vec<SnowflakePartitionInfo>,
176}
177
178#[derive(Debug, Clone, Deserialize)]
179#[serde(rename_all = "camelCase")]
180struct SnowflakePartitionInfo {}
181
182#[derive(Debug, Clone, Deserialize)]
183struct SnowflakeColumnMeta {
184 #[serde(default)]
185 name: String,
186}
187
188fn examples_from_response(
189 response: &SnowflakeStatementResponse,
190) -> Result<impl Iterator<Item = Example> + '_> {
191 if let Some(code) = &response.code {
192 if code != SNOWFLAKE_SUCCESS_CODE {
193 anyhow::bail!(
194 "snowflake sql api returned error code={code} message={}",
195 response.message.as_deref().unwrap_or("<no message>")
196 );
197 }
198 }
199
200 let example_index = response
201 .result_set_meta_data
202 .as_ref()
203 .and_then(|m| {
204 m.row_type.iter().enumerate().find_map(|(index, col)| {
205 if col.name.eq_ignore_ascii_case("example") {
206 Some(index)
207 } else {
208 None
209 }
210 })
211 })
212 .unwrap_or(0);
213
214 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
215 let Some(example_value) = data_row.get(example_index) else {
216 return None;
217 };
218 if example_value.is_null() {
219 return None;
220 }
221
222 let parse_result = match example_value {
223 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
224 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
225 };
226
227 match parse_result {
228 Ok(spec) => Some(Example {
229 spec,
230 prompt_inputs: None,
231 prompt: None,
232 predictions: Vec::new(),
233 score: Vec::new(),
234 state: None,
235 }),
236 Err(error) => {
237 let raw_json = serde_json::to_string_pretty(example_value)
238 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
239 log::error!(
240 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
241 );
242 None
243 }
244 }
245 });
246
247 Ok(iter)
248}
249
250async fn run_sql_with_polling(
251 http_client: Arc<dyn HttpClient>,
252 base_url: &str,
253 token: &str,
254 request: &serde_json::Value,
255 step_progress: &crate::progress::StepProgress,
256 background_executor: BackgroundExecutor,
257) -> Result<SnowflakeStatementResponse> {
258 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
259
260 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
261 let statement_handle = response
262 .statement_handle
263 .as_ref()
264 .context("async query response missing statementHandle")?
265 .clone();
266
267 for attempt in 1..=MAX_POLL_ATTEMPTS {
268 step_progress.set_substatus(format!("polling ({attempt})"));
269
270 background_executor.timer(POLL_INTERVAL).await;
271
272 response =
273 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
274
275 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
276 break;
277 }
278 }
279
280 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
281 anyhow::bail!(
282 "query still running after {} poll attempts ({} seconds)",
283 MAX_POLL_ATTEMPTS,
284 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
285 );
286 }
287 }
288
289 Ok(response)
290}
291
292async fn fetch_partition(
293 http_client: Arc<dyn HttpClient>,
294 base_url: &str,
295 token: &str,
296 statement_handle: &str,
297 partition: usize,
298) -> Result<SnowflakeStatementResponse> {
299 let url = format!(
300 "{}/api/v2/statements/{}?partition={}",
301 base_url.trim_end_matches('/'),
302 statement_handle,
303 partition
304 );
305
306 let http_request = Request::builder()
307 .method(Method::GET)
308 .uri(url.as_str())
309 .header("Authorization", format!("Bearer {token}"))
310 .header(
311 "X-Snowflake-Authorization-Token-Type",
312 "PROGRAMMATIC_ACCESS_TOKEN",
313 )
314 .header("Accept", "application/json")
315 .header("Accept-Encoding", "gzip")
316 .header("User-Agent", "edit_prediction_cli")
317 .body(AsyncBody::empty())?;
318
319 let response = http_client
320 .send(http_request)
321 .await
322 .context("failed to send partition request to Snowflake SQL API")?;
323
324 let status = response.status();
325 let content_encoding = response
326 .headers()
327 .get("content-encoding")
328 .and_then(|v| v.to_str().ok())
329 .map(|s| s.to_lowercase());
330
331 let body_bytes = {
332 use futures::AsyncReadExt as _;
333
334 let mut body = response.into_body();
335 let mut bytes = Vec::new();
336 body.read_to_end(&mut bytes)
337 .await
338 .context("failed to read Snowflake SQL API partition response body")?;
339 bytes
340 };
341
342 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
343 let mut decoder = GzDecoder::new(&body_bytes[..]);
344 let mut decompressed = Vec::new();
345 decoder
346 .read_to_end(&mut decompressed)
347 .context("failed to decompress gzip response")?;
348 decompressed
349 } else {
350 body_bytes
351 };
352
353 if !status.is_success() && status.as_u16() != 202 {
354 let body_text = String::from_utf8_lossy(&body_bytes);
355 anyhow::bail!(
356 "snowflake sql api partition request http {}: {}",
357 status.as_u16(),
358 body_text
359 );
360 }
361
362 if body_bytes.is_empty() {
363 anyhow::bail!(
364 "snowflake sql api partition {} returned empty response body (http {})",
365 partition,
366 status.as_u16()
367 );
368 }
369
370 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
371 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
372 format!(
373 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
374 partition,
375 status.as_u16(),
376 body_preview
377 )
378 })
379}
380
381async fn run_sql(
382 http_client: Arc<dyn HttpClient>,
383 base_url: &str,
384 token: &str,
385 request: &serde_json::Value,
386) -> Result<SnowflakeStatementResponse> {
387 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
388
389 let request_body =
390 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
391
392 let http_request = Request::builder()
393 .method(Method::POST)
394 .uri(url.as_str())
395 .header("Authorization", format!("Bearer {token}"))
396 .header(
397 "X-Snowflake-Authorization-Token-Type",
398 "PROGRAMMATIC_ACCESS_TOKEN",
399 )
400 .header("Content-Type", "application/json")
401 .header("Accept", "application/json")
402 .header("User-Agent", "edit_prediction_cli")
403 .body(AsyncBody::from(request_body.clone()))?;
404
405 let response = http_client
406 .send(http_request)
407 .await
408 .context("failed to send request to Snowflake SQL API")?;
409
410 let status = response.status();
411 let body_bytes = {
412 use futures::AsyncReadExt as _;
413
414 let mut body = response.into_body();
415 let mut bytes = Vec::new();
416 body.read_to_end(&mut bytes)
417 .await
418 .context("failed to read Snowflake SQL API response body")?;
419 bytes
420 };
421
422 if !status.is_success() && status.as_u16() != 202 {
423 let body_text = String::from_utf8_lossy(&body_bytes);
424 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
425 }
426
427 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
428 .context("failed to parse Snowflake SQL API response JSON")
429}
430
431pub async fn fetch_rejected_examples_after(
432 http_client: Arc<dyn HttpClient>,
433 after_timestamps: &[String],
434 max_rows_per_timestamp: usize,
435 background_executor: BackgroundExecutor,
436) -> Result<Vec<Example>> {
437 if after_timestamps.is_empty() {
438 return Ok(Vec::new());
439 }
440
441 let progress = Progress::global();
442
443 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
444 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
445 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
446 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
447 )?;
448 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
449
450 let mut all_examples = Vec::new();
451
452 for after_date in after_timestamps.iter() {
453 let step_progress_name = format!("rejected>{after_date}");
454 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
455 step_progress.set_substatus("querying");
456
457 // Join rejected events with their corresponding request events to get the full context.
458 // We filter for V3 sampling data which contains the structured input we need.
459 // We also filter for predictions that were actually shown to the user (was_shown = true)
460 // to focus on explicit user rejections rather than implicit cancellations.
461 let statement = indoc! {r#"
462 SELECT
463 req.event_properties:request_id::string AS request_id,
464 req.device_id::string AS device_id,
465 req.time::string AS time,
466 req.event_properties:input AS input,
467 req.event_properties:prompt::string AS prompt,
468 req.event_properties:output::string AS output,
469 rej.event_properties:was_shown::boolean AS was_shown,
470 rej.event_properties:reason::string AS reason
471 FROM events req
472 INNER JOIN events rej
473 ON req.event_properties:request_id = rej.event_properties:request_id
474 WHERE req.event_type = ?
475 AND rej.event_type = ?
476 AND req.event_properties:version = 'V3'
477 AND rej.event_properties:was_shown = true
478 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
479 ORDER BY req.time ASC
480 LIMIT ?
481 "#};
482
483 let request = json!({
484 "statement": statement,
485 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
486 "database": "EVENTS",
487 "schema": "PUBLIC",
488 "warehouse": "DBT",
489 "role": role,
490 "bindings": {
491 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
492 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
493 "3": { "type": "TEXT", "value": after_date },
494 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
495 }
496 });
497
498 let response = run_sql_with_polling(
499 http_client.clone(),
500 &base_url,
501 &token,
502 &request,
503 &step_progress,
504 background_executor.clone(),
505 )
506 .await?;
507
508 let total_rows = response
509 .result_set_meta_data
510 .as_ref()
511 .and_then(|m| m.num_rows)
512 .unwrap_or(response.data.len() as i64);
513
514 let num_partitions = response
515 .result_set_meta_data
516 .as_ref()
517 .map(|m| m.partition_info.len())
518 .unwrap_or(1)
519 .max(1);
520
521 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
522 step_progress.set_substatus("parsing");
523
524 all_examples.extend(rejected_examples_from_response(&response)?);
525
526 if num_partitions > 1 {
527 let statement_handle = response
528 .statement_handle
529 .as_ref()
530 .context("response has multiple partitions but no statementHandle")?;
531
532 for partition in 1..num_partitions {
533 step_progress.set_substatus(format!(
534 "fetching partition {}/{}",
535 partition + 1,
536 num_partitions
537 ));
538
539 let partition_response = fetch_partition(
540 http_client.clone(),
541 &base_url,
542 &token,
543 statement_handle,
544 partition,
545 )
546 .await?;
547
548 all_examples.extend(rejected_examples_from_response(&partition_response)?);
549 }
550 }
551
552 step_progress.set_substatus("done");
553 }
554
555 Ok(all_examples)
556}
557
558fn rejected_examples_from_response(
559 response: &SnowflakeStatementResponse,
560) -> Result<impl Iterator<Item = Example> + '_> {
561 if let Some(code) = &response.code {
562 if code != SNOWFLAKE_SUCCESS_CODE {
563 anyhow::bail!(
564 "snowflake sql api returned error code={code} message={}",
565 response.message.as_deref().unwrap_or("<no message>")
566 );
567 }
568 }
569
570 let column_indices = get_column_indices(
571 &response.result_set_meta_data,
572 &[
573 "request_id",
574 "device_id",
575 "time",
576 "input",
577 "prompt",
578 "output",
579 "was_shown",
580 "reason",
581 ],
582 );
583
584 let iter = response
585 .data
586 .iter()
587 .enumerate()
588 .filter_map(move |(row_index, data_row)| {
589 let get_string = |name: &str| -> Option<String> {
590 let index = column_indices.get(name).copied()?;
591 match data_row.get(index)? {
592 JsonValue::String(s) => Some(s.clone()),
593 JsonValue::Null => None,
594 other => Some(other.to_string()),
595 }
596 };
597
598 let get_json = |name: &str| -> Option<JsonValue> {
599 let index = column_indices.get(name).copied()?;
600 let value = data_row.get(index)?;
601 if value.is_null() {
602 return None;
603 }
604 match value {
605 JsonValue::String(s) => serde_json::from_str(s).ok(),
606 other => Some(other.clone()),
607 }
608 };
609
610 let get_bool = |name: &str| -> Option<bool> {
611 let index = column_indices.get(name).copied()?;
612 match data_row.get(index)? {
613 JsonValue::Bool(b) => Some(*b),
614 JsonValue::String(s) => s.parse().ok(),
615 _ => None,
616 }
617 };
618
619 let request_id_str = get_string("request_id");
620 let device_id = get_string("device_id");
621 let time = get_string("time");
622 let input_json = get_json("input");
623 let input: Option<ZetaPromptInput> =
624 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
625 let output = get_string("output");
626 let was_shown = get_bool("was_shown");
627 let reason = get_string("reason");
628
629 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
630 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
631 Some(build_rejected_example(
632 request_id,
633 device_id,
634 time,
635 input,
636 output,
637 was_shown,
638 reason,
639 ))
640 }
641 _ => {
642 log::warn!(
643 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
644 request_id_str.is_some(),
645 device_id.is_some(),
646 time.is_some(),
647 input_json.is_some(),
648 output.is_some(),
649 was_shown.is_some(),
650 reason.is_some()
651 );
652 None
653 }
654 }
655 });
656
657 Ok(iter)
658}
659
660fn build_rejected_example(
661 request_id: String,
662 device_id: String,
663 time: String,
664 input: ZetaPromptInput,
665 output: String,
666 was_shown: bool,
667 reason: String,
668) -> Example {
669 let events: Vec<CapturedEvent> = input
670 .events
671 .iter()
672 .map(|event| match event.as_ref() {
673 zeta_prompt::Event::BufferChange {
674 path,
675 old_path,
676 diff,
677 predicted,
678 in_open_source_repo,
679 } => CapturedEvent {
680 path: path.clone(),
681 old_path: old_path.clone(),
682 diff: diff.clone(),
683 predicted: *predicted,
684 in_open_source_repo: *in_open_source_repo,
685 },
686 })
687 .collect();
688
689 let related_files: Vec<CapturedRelatedFile> = input
690 .related_files
691 .iter()
692 .map(|rf| CapturedRelatedFile {
693 path: rf.path.clone(),
694 max_row: rf.max_row,
695 excerpts: rf
696 .excerpts
697 .iter()
698 .map(|e| CapturedRelatedExcerpt {
699 row_range: e.row_range.clone(),
700 text: e.text.to_string(),
701 })
702 .collect(),
703 })
704 .collect();
705
706 let cursor_excerpt = input.cursor_excerpt.as_ref();
707 let cursor_offset = input.cursor_offset_in_excerpt;
708
709 let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
710
711 let mut edit_history = String::new();
712 for event in &input.events {
713 zeta_prompt::write_event(&mut edit_history, event);
714 edit_history.push('\n');
715 }
716
717 let rejected_patch = build_rejected_patch(
718 &input.cursor_path,
719 cursor_excerpt,
720 &input.editable_range_in_excerpt,
721 &output,
722 );
723
724 let spec = ExampleSpec {
725 name: request_id.clone(),
726 repository_url: String::new(),
727 revision: String::new(),
728 tags: vec![format!("rejection:{}", reason.to_lowercase())],
729 reasoning: None,
730 uncommitted_diff: String::new(),
731 cursor_path: input.cursor_path.clone(),
732 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
733 edit_history,
734 expected_patches: Vec::new(),
735 rejected_patch: Some(rejected_patch),
736 captured_prompt_input: Some(CapturedPromptInput {
737 cursor_file_content: cursor_excerpt.to_string(),
738 cursor_offset,
739 cursor_row,
740 cursor_column,
741 events,
742 related_files,
743 }),
744 telemetry: Some(TelemetrySource {
745 request_id,
746 device_id,
747 time,
748 rejection_reason: reason,
749 was_shown,
750 }),
751 };
752
753 Example {
754 spec,
755 prompt_inputs: None,
756 prompt: None,
757 predictions: Vec::new(),
758 score: Vec::new(),
759 state: None,
760 }
761}
762
/// Converts a byte offset into `text` to a zero-based `(row, column)` pair.
///
/// `row` is the number of `'\n'` bytes before the offset. `column` is the number
/// of bytes between the start of that line and the offset, so it counts bytes, not
/// characters. An offset past the end of `text` is clamped to the end, which keeps
/// the result in step with `build_cursor_position`.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    let offset = offset.min(text.len());
    // Scanning bytes is safe at any offset: the byte 0x0A never occurs inside a
    // multi-byte UTF-8 sequence, and a byte slice has no char-boundary rule.
    let prefix = &text.as_bytes()[..offset];

    let row = prefix.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = prefix
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline| newline + 1);

    (row, (offset - line_start) as u32)
}
778
/// Returns `excerpt` with the marker `[CURSOR_POSITION]` inserted at
/// `cursor_offset`.
///
/// `cursor_offset` is a byte offset. It is clamped to the length of `excerpt`, and
/// an offset that falls inside a multi-byte character is moved back to the start
/// of that character, so the function never panics on an arbitrary offset.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
784
785fn build_rejected_patch(
786 cursor_path: &std::path::Path,
787 cursor_excerpt: &str,
788 editable_range: &std::ops::Range<usize>,
789 model_output: &str,
790) -> String {
791 let old_text = &cursor_excerpt[editable_range.clone()];
792
793 let editable_start_row = cursor_excerpt[..editable_range.start]
794 .chars()
795 .filter(|&c| c == '\n')
796 .count() as u32;
797
798 let diff_body = language::unified_diff_with_offsets(
799 old_text,
800 model_output,
801 editable_start_row,
802 editable_start_row,
803 );
804
805 let mut patch = String::new();
806 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
807 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
808 patch.push_str(&diff_body);
809 patch
810}
811
812fn get_column_indices(
813 meta: &Option<SnowflakeResultSetMetaData>,
814 names: &[&str],
815) -> std::collections::HashMap<String, usize> {
816 let mut indices = std::collections::HashMap::new();
817 if let Some(meta) = meta {
818 for (index, col) in meta.row_type.iter().enumerate() {
819 for &name in names {
820 if col.name.eq_ignore_ascii_case(name) {
821 indices.insert(name.to_string(), index);
822 }
823 }
824 }
825 }
826 indices
827}