1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11use telemetry_events::EditPredictionRating;
12
13use zeta_prompt::ZetaPromptInput;
14
15use crate::example::Example;
16use crate::progress::{InfoStyle, Progress, Step};
17use crate::sync_deployments::EDIT_PREDICTION_DEPLOYMENT_EVENT;
18use edit_prediction::example_spec::{
19 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
20 TelemetrySource,
21};
22use std::fmt::Write as _;
23
/// Response `code` value that the row decoders in this file accept as success; any other
/// code present in a response is reported as an error.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` value that makes `run_sql_with_polling` keep polling the statement.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` bound by the captured-examples query.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` of the request events selected (or joined against) by the rejected,
/// requested and rated queries.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` of the rejection events joined to requests by the rejected-examples query.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// `event_type` of the rating events selected by the rated-examples query.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
30
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
/// where `zed_version >= 0.224.1`.
///
/// Only the minor and patch components are represented; the queries in this file compare
/// them against the second and third dot-separated parts of the recorded `zed_version`.
#[derive(Clone, Copy, Debug)]
pub struct MinCaptureVersion {
    /// Minimum minor version (the second dot-separated component).
    pub minor: u32,
    /// Minimum patch version, applied only when the minor version is equal to `minor`.
    pub patch: u32,
}
39
/// Value sent as the `timeout` field of every statement request built in this file.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each poll of a statement that is still running.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls before `run_sql_with_polling` gives up with an error.
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
43
/// Recognizes a `captured-after:{timestamp}` token.
///
/// Returns the text following the prefix (possibly empty), or `None` when `input` does not
/// start with `captured-after:`. The timestamp itself is not validated.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input.strip_prefix(PREFIX)
}
48
/// Recognizes a `rejected-after:{timestamp}` token.
///
/// Returns the text following the prefix (possibly empty), or `None` when `input` does not
/// start with `rejected-after:`. The timestamp itself is not validated.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    input.strip_prefix(PREFIX)
}
53
/// Recognizes a `requested-after:{timestamp}` token.
///
/// Returns the text following the prefix (possibly empty), or `None` when `input` does not
/// start with `requested-after:`. The timestamp itself is not validated.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    input.strip_prefix(PREFIX)
}
58
59/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
60/// or `rated-negative-after:{timestamp}`.
61/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
62pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
63 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
64 Some((timestamp, Some(EditPredictionRating::Positive)))
65 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
66 Some((timestamp, Some(EditPredictionRating::Negative)))
67 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
68 Some((timestamp, None))
69 } else {
70 None
71 }
72}
73
74pub async fn fetch_captured_examples_after(
75 http_client: Arc<dyn HttpClient>,
76 after_timestamps: &[String],
77 max_rows_per_timestamp: usize,
78 offset: usize,
79 background_executor: BackgroundExecutor,
80 _min_capture_version: Option<MinCaptureVersion>,
81) -> Result<Vec<Example>> {
82 if after_timestamps.is_empty() {
83 return Ok(Vec::new());
84 }
85
86 let progress = Progress::global();
87
88 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
89 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
90 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
91 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
92 )?;
93 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
94
95 let mut all_examples = Vec::new();
96
97 for after_date in after_timestamps.iter() {
98 let step_progress_name = format!(">{after_date}");
99 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
100 step_progress.set_substatus("querying");
101
102 let statement = indoc! {r#"
103 SELECT
104 event_properties:example AS example
105 FROM events
106 WHERE event_type = ?
107 AND time > TRY_TO_TIMESTAMP_NTZ(?)
108 AND event_properties:can_collect_data = true
109 ORDER BY time ASC
110 LIMIT ?
111 OFFSET ?
112 "#};
113
114 let request = json!({
115 "statement": statement,
116 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
117 "database": "EVENTS",
118 "schema": "PUBLIC",
119 "warehouse": "DBT",
120 "role": role,
121 "bindings": {
122 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
123 "2": { "type": "TEXT", "value": after_date },
124 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
125 "4": { "type": "FIXED", "value": offset.to_string() }
126 }
127 });
128
129 let response = run_sql_with_polling(
130 http_client.clone(),
131 &base_url,
132 &token,
133 &request,
134 &step_progress,
135 background_executor.clone(),
136 )
137 .await?;
138
139 let total_rows = response
140 .result_set_meta_data
141 .as_ref()
142 .and_then(|m| m.num_rows)
143 .unwrap_or(response.data.len() as i64);
144
145 let num_partitions = response
146 .result_set_meta_data
147 .as_ref()
148 .map(|m| m.partition_info.len())
149 .unwrap_or(1)
150 .max(1);
151
152 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
153 step_progress.set_substatus("parsing");
154
155 let example_index = response
156 .result_set_meta_data
157 .as_ref()
158 .and_then(|m| {
159 m.row_type.iter().enumerate().find_map(|(index, col)| {
160 if col.name.eq_ignore_ascii_case("example") {
161 Some(index)
162 } else {
163 None
164 }
165 })
166 })
167 .unwrap_or(0);
168
169 all_examples.extend(examples_from_response(&response, example_index)?);
170
171 if num_partitions > 1 {
172 let statement_handle = response
173 .statement_handle
174 .as_ref()
175 .context("response has multiple partitions but no statementHandle")?;
176
177 for partition in 1..num_partitions {
178 step_progress.set_substatus(format!(
179 "fetching partition {}/{}",
180 partition + 1,
181 num_partitions
182 ));
183
184 let partition_response = fetch_partition(
185 http_client.clone(),
186 &base_url,
187 &token,
188 statement_handle,
189 partition,
190 )
191 .await?;
192
193 all_examples.extend(examples_from_response(&partition_response, example_index)?);
194 }
195 }
196
197 step_progress.set_substatus("done");
198 }
199
200 Ok(all_examples)
201}
202
/// Deserialized JSON body of a Snowflake SQL API statement or partition response.
///
/// Field names are read in camelCase, and every field falls back to its default when the
/// key is absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON values, one per selected column.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Column, row-count and partition metadata, when the response includes it.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE` elsewhere in this file.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Message accompanying `code`; included in error reports.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll the statement and to fetch further partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
217
/// Metadata section (`resultSetMetaData`) of a statement response.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeResultSetMetaData {
    /// Column descriptors; only their names are read, to locate columns by name.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported by the server; used only for progress display.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
228
/// One entry of the `partitionInfo` list. No fields are declared: the code in this file only
/// counts the entries to learn how many partitions a result has.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
232
/// Descriptor of one result column; only the column name is deserialized.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; defaults to an empty string when absent from the JSON.
    #[serde(default)]
    name: String,
}
238
239fn examples_from_response(
240 response: &SnowflakeStatementResponse,
241 example_index: usize,
242) -> Result<impl Iterator<Item = Example> + '_> {
243 if let Some(code) = &response.code {
244 if code != SNOWFLAKE_SUCCESS_CODE {
245 anyhow::bail!(
246 "snowflake sql api returned error code={code} message={}",
247 response.message.as_deref().unwrap_or("<no message>")
248 );
249 }
250 }
251
252 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
253 let Some(example_value) = data_row.get(example_index) else {
254 return None;
255 };
256 if example_value.is_null() {
257 return None;
258 }
259
260 let parse_result = match example_value {
261 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
262 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
263 };
264
265 match parse_result {
266 Ok(spec) => Some(Example {
267 spec,
268 prompt_inputs: None,
269 prompt: None,
270 predictions: Vec::new(),
271 score: Vec::new(),
272 qa: Vec::new(),
273 state: None,
274 }),
275 Err(error) => {
276 let raw_json = serde_json::to_string_pretty(example_value)
277 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
278 log::error!(
279 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
280 );
281 None
282 }
283 }
284 });
285
286 Ok(iter)
287}
288
289async fn run_sql_with_polling(
290 http_client: Arc<dyn HttpClient>,
291 base_url: &str,
292 token: &str,
293 request: &serde_json::Value,
294 step_progress: &crate::progress::StepProgress,
295 background_executor: BackgroundExecutor,
296) -> Result<SnowflakeStatementResponse> {
297 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
298
299 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
300 let statement_handle = response
301 .statement_handle
302 .as_ref()
303 .context("async query response missing statementHandle")?
304 .clone();
305
306 for attempt in 1..=MAX_POLL_ATTEMPTS {
307 step_progress.set_substatus(format!("polling ({attempt})"));
308
309 background_executor.timer(POLL_INTERVAL).await;
310
311 response =
312 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
313
314 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
315 break;
316 }
317 }
318
319 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
320 anyhow::bail!(
321 "query still running after {} poll attempts ({} seconds)",
322 MAX_POLL_ATTEMPTS,
323 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
324 );
325 }
326 }
327
328 Ok(response)
329}
330
331pub(crate) async fn fetch_partition(
332 http_client: Arc<dyn HttpClient>,
333 base_url: &str,
334 token: &str,
335 statement_handle: &str,
336 partition: usize,
337) -> Result<SnowflakeStatementResponse> {
338 let url = format!(
339 "{}/api/v2/statements/{}?partition={}",
340 base_url.trim_end_matches('/'),
341 statement_handle,
342 partition
343 );
344
345 let http_request = Request::builder()
346 .method(Method::GET)
347 .uri(url.as_str())
348 .header("Authorization", format!("Bearer {token}"))
349 .header(
350 "X-Snowflake-Authorization-Token-Type",
351 "PROGRAMMATIC_ACCESS_TOKEN",
352 )
353 .header("Accept", "application/json")
354 .header("Accept-Encoding", "gzip")
355 .header("User-Agent", "edit_prediction_cli")
356 .body(AsyncBody::empty())?;
357
358 let response = http_client
359 .send(http_request)
360 .await
361 .context("failed to send partition request to Snowflake SQL API")?;
362
363 let status = response.status();
364 let content_encoding = response
365 .headers()
366 .get("content-encoding")
367 .and_then(|v| v.to_str().ok())
368 .map(|s| s.to_lowercase());
369
370 let body_bytes = {
371 use futures::AsyncReadExt as _;
372
373 let mut body = response.into_body();
374 let mut bytes = Vec::new();
375 body.read_to_end(&mut bytes)
376 .await
377 .context("failed to read Snowflake SQL API partition response body")?;
378 bytes
379 };
380
381 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
382 let mut decoder = GzDecoder::new(&body_bytes[..]);
383 let mut decompressed = Vec::new();
384 decoder
385 .read_to_end(&mut decompressed)
386 .context("failed to decompress gzip response")?;
387 decompressed
388 } else {
389 body_bytes
390 };
391
392 if !status.is_success() && status.as_u16() != 202 {
393 let body_text = String::from_utf8_lossy(&body_bytes);
394 anyhow::bail!(
395 "snowflake sql api partition request http {}: {}",
396 status.as_u16(),
397 body_text
398 );
399 }
400
401 if body_bytes.is_empty() {
402 anyhow::bail!(
403 "snowflake sql api partition {} returned empty response body (http {})",
404 partition,
405 status.as_u16()
406 );
407 }
408
409 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
410 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
411 format!(
412 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
413 partition,
414 status.as_u16(),
415 body_preview
416 )
417 })
418}
419
420pub(crate) async fn run_sql(
421 http_client: Arc<dyn HttpClient>,
422 base_url: &str,
423 token: &str,
424 request: &serde_json::Value,
425) -> Result<SnowflakeStatementResponse> {
426 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
427
428 let request_body =
429 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
430
431 let http_request = Request::builder()
432 .method(Method::POST)
433 .uri(url.as_str())
434 .header("Authorization", format!("Bearer {token}"))
435 .header(
436 "X-Snowflake-Authorization-Token-Type",
437 "PROGRAMMATIC_ACCESS_TOKEN",
438 )
439 .header("Content-Type", "application/json")
440 .header("Accept", "application/json")
441 .header("User-Agent", "edit_prediction_cli")
442 .body(AsyncBody::from(request_body.clone()))?;
443
444 let response = http_client
445 .send(http_request)
446 .await
447 .context("failed to send request to Snowflake SQL API")?;
448
449 let status = response.status();
450 let body_bytes = {
451 use futures::AsyncReadExt as _;
452
453 let mut body = response.into_body();
454 let mut bytes = Vec::new();
455 body.read_to_end(&mut bytes)
456 .await
457 .context("failed to read Snowflake SQL API response body")?;
458 bytes
459 };
460
461 if !status.is_success() && status.as_u16() != 202 {
462 let body_text = String::from_utf8_lossy(&body_bytes);
463 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
464 }
465
466 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
467 .context("failed to parse Snowflake SQL API response JSON")
468}
469
470pub async fn fetch_rejected_examples_after(
471 http_client: Arc<dyn HttpClient>,
472 after_timestamps: &[String],
473 max_rows_per_timestamp: usize,
474 offset: usize,
475 background_executor: BackgroundExecutor,
476 min_capture_version: Option<MinCaptureVersion>,
477) -> Result<Vec<Example>> {
478 if after_timestamps.is_empty() {
479 return Ok(Vec::new());
480 }
481
482 let progress = Progress::global();
483
484 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
485 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
486 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
487 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
488 )?;
489 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
490
491 let mut all_examples = Vec::new();
492
493 for after_date in after_timestamps.iter() {
494 let step_progress_name = format!("rejected>{after_date}");
495 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
496 step_progress.set_substatus("querying");
497
498 // Join rejected events with their corresponding request events to get the full context.
499 // We filter for V3 sampling data which contains the structured input we need.
500 // We also filter for predictions that were actually shown to the user (was_shown = true)
501 // to focus on explicit user rejections rather than implicit cancellations.
502 let statement = indoc! {r#"
503 SELECT
504 req.event_properties:request_id::string AS request_id,
505 req.device_id::string AS device_id,
506 req.time::string AS time,
507 req.event_properties:input AS input,
508 req.event_properties:prompt::string AS prompt,
509 req.event_properties:output::string AS output,
510 rej.event_properties:was_shown::boolean AS was_shown,
511 rej.event_properties:reason::string AS reason,
512 req.event_properties:zed_version::string AS zed_version
513 FROM events req
514 INNER JOIN events rej
515 ON req.event_properties:request_id = rej.event_properties:request_id
516 WHERE req.event_type = ?
517 AND rej.event_type = ?
518 AND req.event_properties:version = 'V3'
519 AND rej.event_properties:was_shown = true
520 AND req.event_properties:can_collect_data = true
521 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
522 AND (? IS NULL OR (
523 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
524 OR (
525 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
526 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
527 )
528 ))
529 ORDER BY req.time ASC
530 LIMIT ?
531 OFFSET ?
532 "#};
533
534 let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
535 let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
536 let min_minor_str_ref = min_minor_str.as_deref();
537 let min_patch_str_ref = min_patch_str.as_deref();
538 let request = json!({
539 "statement": statement,
540 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
541 "database": "EVENTS",
542 "schema": "PUBLIC",
543 "warehouse": "DBT",
544 "role": role,
545 "bindings": {
546 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
547 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
548 "3": { "type": "TEXT", "value": after_date },
549 "4": { "type": "FIXED", "value": min_minor_str_ref },
550 "5": { "type": "FIXED", "value": min_minor_str_ref },
551 "6": { "type": "FIXED", "value": min_minor_str_ref },
552 "7": { "type": "FIXED", "value": min_patch_str_ref },
553 "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
554 "9": { "type": "FIXED", "value": offset.to_string() }
555 }
556 });
557
558 let response = run_sql_with_polling(
559 http_client.clone(),
560 &base_url,
561 &token,
562 &request,
563 &step_progress,
564 background_executor.clone(),
565 )
566 .await?;
567
568 let total_rows = response
569 .result_set_meta_data
570 .as_ref()
571 .and_then(|m| m.num_rows)
572 .unwrap_or(response.data.len() as i64);
573
574 let num_partitions = response
575 .result_set_meta_data
576 .as_ref()
577 .map(|m| m.partition_info.len())
578 .unwrap_or(1)
579 .max(1);
580
581 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
582 step_progress.set_substatus("parsing");
583
584 let column_indices = get_column_indices(
585 &response.result_set_meta_data,
586 &[
587 "request_id",
588 "device_id",
589 "time",
590 "input",
591 "prompt",
592 "output",
593 "was_shown",
594 "reason",
595 "zed_version",
596 ],
597 );
598
599 all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
600
601 if num_partitions > 1 {
602 let statement_handle = response
603 .statement_handle
604 .as_ref()
605 .context("response has multiple partitions but no statementHandle")?;
606
607 for partition in 1..num_partitions {
608 step_progress.set_substatus(format!(
609 "fetching partition {}/{}",
610 partition + 1,
611 num_partitions
612 ));
613
614 let partition_response = fetch_partition(
615 http_client.clone(),
616 &base_url,
617 &token,
618 statement_handle,
619 partition,
620 )
621 .await?;
622
623 all_examples.extend(rejected_examples_from_response(
624 &partition_response,
625 &column_indices,
626 )?);
627 }
628 }
629
630 step_progress.set_substatus("done");
631 }
632
633 Ok(all_examples)
634}
635
636pub async fn fetch_requested_examples_after(
637 http_client: Arc<dyn HttpClient>,
638 after_timestamps: &[String],
639 max_rows_per_timestamp: usize,
640 offset: usize,
641 background_executor: BackgroundExecutor,
642 min_capture_version: Option<MinCaptureVersion>,
643) -> Result<Vec<Example>> {
644 if after_timestamps.is_empty() {
645 return Ok(Vec::new());
646 }
647
648 let progress = Progress::global();
649
650 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
651 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
652 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
653 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
654 )?;
655 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
656
657 let mut all_examples = Vec::new();
658
659 for after_date in after_timestamps.iter() {
660 let step_progress_name = format!("requested>{after_date}");
661 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
662 step_progress.set_substatus("querying");
663
664 let statement = indoc! {r#"
665 SELECT
666 req.event_properties:request_id::string AS request_id,
667 req.device_id::string AS device_id,
668 req.time::string AS time,
669 req.event_properties:input AS input,
670 req.event_properties:zed_version::string AS zed_version
671 FROM events req
672 WHERE req.event_type = ?
673 AND req.event_properties:version = 'V3'
674 AND req.event_properties:can_collect_data = true
675 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
676 AND (? IS NULL OR (
677 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
678 OR (
679 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
680 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
681 )
682 ))
683 ORDER BY req.time ASC
684 LIMIT ?
685 OFFSET ?
686 "#};
687
688 let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
689 let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
690 let min_minor_str_ref = min_minor_str.as_deref();
691 let min_patch_str_ref = min_patch_str.as_deref();
692 let request = json!({
693 "statement": statement,
694 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
695 "database": "EVENTS",
696 "schema": "PUBLIC",
697 "warehouse": "DBT",
698 "role": role,
699 "bindings": {
700 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
701 "2": { "type": "TEXT", "value": after_date },
702 "3": { "type": "FIXED", "value": min_minor_str_ref },
703 "4": { "type": "FIXED", "value": min_minor_str_ref },
704 "5": { "type": "FIXED", "value": min_minor_str_ref },
705 "6": { "type": "FIXED", "value": min_patch_str_ref },
706 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
707 "8": { "type": "FIXED", "value": offset.to_string() }
708 }
709 });
710
711 let response = run_sql_with_polling(
712 http_client.clone(),
713 &base_url,
714 &token,
715 &request,
716 &step_progress,
717 background_executor.clone(),
718 )
719 .await?;
720
721 let total_rows = response
722 .result_set_meta_data
723 .as_ref()
724 .and_then(|m| m.num_rows)
725 .unwrap_or(response.data.len() as i64);
726
727 let num_partitions = response
728 .result_set_meta_data
729 .as_ref()
730 .map(|m| m.partition_info.len())
731 .unwrap_or(1)
732 .max(1);
733
734 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
735 step_progress.set_substatus("parsing");
736
737 let column_indices = get_column_indices(
738 &response.result_set_meta_data,
739 &["request_id", "device_id", "time", "input", "zed_version"],
740 );
741
742 all_examples.extend(requested_examples_from_response(
743 &response,
744 &column_indices,
745 )?);
746
747 if num_partitions > 1 {
748 let statement_handle = response
749 .statement_handle
750 .as_ref()
751 .context("response has multiple partitions but no statementHandle")?;
752
753 for partition in 1..num_partitions {
754 step_progress.set_substatus(format!(
755 "fetching partition {}/{}",
756 partition + 1,
757 num_partitions
758 ));
759
760 let partition_response = fetch_partition(
761 http_client.clone(),
762 &base_url,
763 &token,
764 statement_handle,
765 partition,
766 )
767 .await?;
768
769 all_examples.extend(requested_examples_from_response(
770 &partition_response,
771 &column_indices,
772 )?);
773 }
774 }
775
776 step_progress.set_substatus("done");
777 }
778
779 Ok(all_examples)
780}
781
782pub async fn fetch_rated_examples_after(
783 http_client: Arc<dyn HttpClient>,
784 inputs: &[(String, Option<EditPredictionRating>)],
785 max_rows_per_timestamp: usize,
786 offset: usize,
787 background_executor: BackgroundExecutor,
788 min_capture_version: Option<MinCaptureVersion>,
789) -> Result<Vec<Example>> {
790 if inputs.is_empty() {
791 return Ok(Vec::new());
792 }
793
794 let progress = Progress::global();
795
796 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
797 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
798 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
799 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
800 )?;
801 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
802
803 let mut all_examples = Vec::new();
804
805 for (after_date, rating_filter) in inputs.iter() {
806 let filter_label = match rating_filter {
807 None => "",
808 Some(EditPredictionRating::Positive) => ":positive",
809 Some(EditPredictionRating::Negative) => ":negative",
810 };
811 let step_progress_name = format!("rated{filter_label}>{after_date}");
812 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
813 step_progress.set_substatus("querying");
814
815 let rating_value = rating_filter.as_ref().map(|r| match r {
816 EditPredictionRating::Positive => "Positive",
817 EditPredictionRating::Negative => "Negative",
818 });
819
820 let statement = indoc! {r#"
821 SELECT
822 rated.event_properties:request_id::string AS request_id,
823 rated.event_properties:inputs AS inputs,
824 rated.event_properties:output::string AS output,
825 rated.event_properties:rating::string AS rating,
826 rated.event_properties:feedback::string AS feedback,
827 rated.device_id::string AS device_id,
828 rated.time::string AS time,
829 deploy.event_properties:experiment_name::string AS experiment_name,
830 deploy.event_properties:environment::string AS environment,
831 rated.event_properties:zed_version::string AS zed_version
832 FROM events rated
833 LEFT JOIN events req
834 ON rated.event_properties:request_id::string = req.event_properties:request_id::string
835 AND req.event_type = ?
836 LEFT JOIN events deploy
837 ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
838 AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
839 AND deploy.event_type = ?
840 WHERE rated.event_type = ?
841 AND (? IS NULL OR rated.event_properties:rating::string = ?)
842 AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
843 AND rated.event_properties:inputs IS NOT NULL
844 AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
845 AND rated.event_properties:output IS NOT NULL
846 AND rated.event_properties:can_collect_data = true
847 AND (? IS NULL OR (
848 TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
849 OR (
850 TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
851 AND TRY_CAST(SPLIT_PART(SPLIT_PART(rated.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
852 )
853 ))
854 ORDER BY rated.time ASC
855 LIMIT ?
856 OFFSET ?
857 "#};
858
859 let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
860 let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
861 let min_minor_str_ref = min_minor_str.as_deref();
862 let min_patch_str_ref = min_patch_str.as_deref();
863 let bindings = json!({
864 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
865 "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
866 "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
867 "4": { "type": "TEXT", "value": rating_value },
868 "5": { "type": "TEXT", "value": rating_value },
869 "6": { "type": "TEXT", "value": after_date },
870 "7": { "type": "FIXED", "value": min_minor_str_ref },
871 "8": { "type": "FIXED", "value": min_minor_str_ref },
872 "9": { "type": "FIXED", "value": min_minor_str_ref },
873 "10": { "type": "FIXED", "value": min_patch_str_ref },
874 "11": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
875 "12": { "type": "FIXED", "value": offset.to_string() }
876 });
877
878 let request = json!({
879 "statement": statement,
880 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
881 "database": "EVENTS",
882 "schema": "PUBLIC",
883 "warehouse": "DBT",
884 "role": role,
885 "bindings": bindings
886 });
887
888 let response = run_sql_with_polling(
889 http_client.clone(),
890 &base_url,
891 &token,
892 &request,
893 &step_progress,
894 background_executor.clone(),
895 )
896 .await?;
897
898 let total_rows = response
899 .result_set_meta_data
900 .as_ref()
901 .and_then(|m| m.num_rows)
902 .unwrap_or(response.data.len() as i64);
903
904 let num_partitions = response
905 .result_set_meta_data
906 .as_ref()
907 .map(|m| m.partition_info.len())
908 .unwrap_or(1)
909 .max(1);
910
911 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
912 step_progress.set_substatus("parsing");
913
914 let column_indices = get_column_indices(
915 &response.result_set_meta_data,
916 &[
917 "request_id",
918 "inputs",
919 "output",
920 "rating",
921 "feedback",
922 "device_id",
923 "time",
924 "experiment_name",
925 "environment",
926 "zed_version",
927 ],
928 );
929
930 all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
931
932 if num_partitions > 1 {
933 let statement_handle = response
934 .statement_handle
935 .as_ref()
936 .context("response has multiple partitions but no statementHandle")?;
937
938 for partition in 1..num_partitions {
939 step_progress.set_substatus(format!(
940 "fetching partition {}/{}",
941 partition + 1,
942 num_partitions
943 ));
944
945 let partition_response = fetch_partition(
946 http_client.clone(),
947 &base_url,
948 &token,
949 statement_handle,
950 partition,
951 )
952 .await?;
953
954 all_examples.extend(rated_examples_from_response(
955 &partition_response,
956 &column_indices,
957 )?);
958 }
959 }
960
961 step_progress.set_substatus("done");
962 }
963
964 Ok(all_examples)
965}
966
967fn rated_examples_from_response<'a>(
968 response: &'a SnowflakeStatementResponse,
969 column_indices: &'a std::collections::HashMap<String, usize>,
970) -> Result<impl Iterator<Item = Example> + 'a> {
971 if let Some(code) = &response.code {
972 if code != SNOWFLAKE_SUCCESS_CODE {
973 anyhow::bail!(
974 "snowflake sql api returned error code={code} message={}",
975 response.message.as_deref().unwrap_or("<no message>")
976 );
977 }
978 }
979
980 let iter = response
981 .data
982 .iter()
983 .enumerate()
984 .filter_map(move |(row_index, data_row)| {
985 let get_string = |name: &str| -> Option<String> {
986 let index = column_indices.get(name).copied()?;
987 match data_row.get(index)? {
988 JsonValue::String(s) => Some(s.clone()),
989 JsonValue::Null => None,
990 other => Some(other.to_string()),
991 }
992 };
993
994 let get_json = |name: &str| -> Option<JsonValue> {
995 let index = column_indices.get(name).copied()?;
996 let value = data_row.get(index)?;
997 if value.is_null() {
998 return None;
999 }
1000 match value {
1001 JsonValue::String(s) => serde_json::from_str(s).ok(),
1002 other => Some(other.clone()),
1003 }
1004 };
1005
1006 let request_id = get_string("request_id");
1007 let inputs_json = get_json("inputs");
1008 let inputs: Option<ZetaPromptInput> = match &inputs_json {
1009 Some(v) => match serde_json::from_value(v.clone()) {
1010 Ok(parsed) => Some(parsed),
1011 Err(e) => {
1012 log::warn!(
1013 "skipping row {row_index}: failed to parse inputs - {e}",
1014 );
1015 return None;
1016 }
1017 },
1018 None => None,
1019 };
1020 let output = get_string("output");
1021 let rating = get_string("rating");
1022 let feedback = get_string("feedback").unwrap_or_default();
1023 let device_id = get_string("device_id");
1024 let time = get_string("time");
1025 let experiment_name = get_string("experiment_name");
1026 let environment = get_string("environment");
1027 let zed_version = get_string("zed_version");
1028
1029 match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
1030 (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
1031 Some(build_rated_example(
1032 request_id,
1033 device_id,
1034 time,
1035 inputs,
1036 output,
1037 rating,
1038 feedback,
1039 experiment_name,
1040 environment,
1041 zed_version,
1042 ))
1043 }
1044 _ => {
1045 log::warn!(
1046 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
1047 inputs_json.is_some(),
1048 output.is_some(),
1049 rating.is_some(),
1050 device_id.is_some(),
1051 time.is_some(),
1052 );
1053 None
1054 }
1055 }
1056 });
1057
1058 Ok(iter)
1059}
1060
1061fn build_rated_example(
1062 request_id: Option<String>,
1063 device_id: String,
1064 time: String,
1065 input: ZetaPromptInput,
1066 output: String,
1067 rating: String,
1068 feedback: String,
1069 experiment_name: Option<String>,
1070 environment: Option<String>,
1071 zed_version: Option<String>,
1072) -> Example {
1073 let parsed_rating = if rating == "Positive" {
1074 EditPredictionRating::Positive
1075 } else {
1076 EditPredictionRating::Negative
1077 };
1078 let is_positive = parsed_rating == EditPredictionRating::Positive;
1079 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1080
1081 let mut tags = Vec::with_capacity(3);
1082 tags.push(if is_positive {
1083 "rated:positive".to_string()
1084 } else {
1085 "rated:negative".to_string()
1086 });
1087 if let Some(experiment) = experiment_name {
1088 tags.push(format!("experiment:{experiment}"));
1089 }
1090 if let Some(env) = environment {
1091 tags.push(format!("environment:{env}"));
1092 }
1093
1094 let mut example =
1095 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1096
1097 example.spec.rating = Some(parsed_rating);
1098
1099 if !feedback.is_empty() {
1100 example
1101 .spec
1102 .human_feedback
1103 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1104 }
1105
1106 if is_positive {
1107 example.spec.expected_patches = vec![output];
1108 } else {
1109 example.spec.rejected_patch = Some(output);
1110 }
1111
1112 example
1113}
1114
1115fn requested_examples_from_response<'a>(
1116 response: &'a SnowflakeStatementResponse,
1117 column_indices: &'a std::collections::HashMap<String, usize>,
1118) -> Result<impl Iterator<Item = Example> + 'a> {
1119 if let Some(code) = &response.code {
1120 if code != SNOWFLAKE_SUCCESS_CODE {
1121 anyhow::bail!(
1122 "snowflake sql api returned error code={code} message={}",
1123 response.message.as_deref().unwrap_or("<no message>")
1124 );
1125 }
1126 }
1127
1128 let iter = response
1129 .data
1130 .iter()
1131 .enumerate()
1132 .filter_map(move |(row_index, data_row)| {
1133 let get_string = |name: &str| -> Option<String> {
1134 let index = column_indices.get(name).copied()?;
1135 match data_row.get(index)? {
1136 JsonValue::String(s) => Some(s.clone()),
1137 JsonValue::Null => None,
1138 other => Some(other.to_string()),
1139 }
1140 };
1141
1142 let get_json = |name: &str| -> Option<JsonValue> {
1143 let index = column_indices.get(name).copied()?;
1144 let value = data_row.get(index)?;
1145 if value.is_null() {
1146 return None;
1147 }
1148 match value {
1149 JsonValue::String(s) => serde_json::from_str(s).ok(),
1150 other => Some(other.clone()),
1151 }
1152 };
1153
1154 let request_id_str = get_string("request_id");
1155 let device_id = get_string("device_id");
1156 let time = get_string("time");
1157 let input_json = get_json("input");
1158 let input: Option<ZetaPromptInput> =
1159 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1160 let zed_version = get_string("zed_version");
1161
1162 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1163 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1164 Some(build_example_from_snowflake(
1165 request_id,
1166 device_id,
1167 time,
1168 input,
1169 vec!["requested".to_string()],
1170 None,
1171 zed_version,
1172 ))
1173 }
1174 _ => {
1175 log::warn!(
1176 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1177 request_id_str.is_some(),
1178 device_id.is_some(),
1179 time.is_some(),
1180 input_json.is_some(),
1181 );
1182 None
1183 }
1184 }
1185 });
1186
1187 Ok(iter)
1188}
1189
1190fn rejected_examples_from_response<'a>(
1191 response: &'a SnowflakeStatementResponse,
1192 column_indices: &'a std::collections::HashMap<String, usize>,
1193) -> Result<impl Iterator<Item = Example> + 'a> {
1194 if let Some(code) = &response.code {
1195 if code != SNOWFLAKE_SUCCESS_CODE {
1196 anyhow::bail!(
1197 "snowflake sql api returned error code={code} message={}",
1198 response.message.as_deref().unwrap_or("<no message>")
1199 );
1200 }
1201 }
1202
1203 let iter = response
1204 .data
1205 .iter()
1206 .enumerate()
1207 .filter_map(move |(row_index, data_row)| {
1208 let get_string = |name: &str| -> Option<String> {
1209 let index = column_indices.get(name).copied()?;
1210 match data_row.get(index)? {
1211 JsonValue::String(s) => Some(s.clone()),
1212 JsonValue::Null => None,
1213 other => Some(other.to_string()),
1214 }
1215 };
1216
1217 let get_json = |name: &str| -> Option<JsonValue> {
1218 let index = column_indices.get(name).copied()?;
1219 let value = data_row.get(index)?;
1220 if value.is_null() {
1221 return None;
1222 }
1223 match value {
1224 JsonValue::String(s) => serde_json::from_str(s).ok(),
1225 other => Some(other.clone()),
1226 }
1227 };
1228
1229 let get_bool = |name: &str| -> Option<bool> {
1230 let index = column_indices.get(name).copied()?;
1231 match data_row.get(index)? {
1232 JsonValue::Bool(b) => Some(*b),
1233 JsonValue::String(s) => s.parse().ok(),
1234 _ => None,
1235 }
1236 };
1237
1238 let request_id_str = get_string("request_id");
1239 let device_id = get_string("device_id");
1240 let time = get_string("time");
1241 let input_json = get_json("input");
1242 let input: Option<ZetaPromptInput> =
1243 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1244 let output = get_string("output");
1245 let was_shown = get_bool("was_shown");
1246 let reason = get_string("reason");
1247 let zed_version = get_string("zed_version");
1248
1249 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1250 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1251 Some(build_rejected_example(
1252 request_id,
1253 device_id,
1254 time,
1255 input,
1256 output,
1257 was_shown,
1258 reason,
1259 zed_version,
1260 ))
1261 }
1262 _ => {
1263 log::warn!(
1264 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1265 request_id_str.is_some(),
1266 device_id.is_some(),
1267 time.is_some(),
1268 input_json.is_some(),
1269 output.is_some(),
1270 was_shown.is_some(),
1271 reason.is_some()
1272 );
1273 None
1274 }
1275 }
1276 });
1277
1278 Ok(iter)
1279}
1280
1281fn build_rejected_example(
1282 request_id: String,
1283 device_id: String,
1284 time: String,
1285 input: ZetaPromptInput,
1286 output: String,
1287 was_shown: bool,
1288 reason: String,
1289 zed_version: Option<String>,
1290) -> Example {
1291 let rejected_patch = build_output_patch(
1292 &input.cursor_path,
1293 input.cursor_excerpt.as_ref(),
1294 &input.editable_range_in_excerpt,
1295 &output,
1296 );
1297 let mut example = build_example_from_snowflake(
1298 request_id,
1299 device_id,
1300 time,
1301 input,
1302 vec![format!("rejection:{}", reason.to_lowercase())],
1303 Some(RejectionInfo { reason, was_shown }),
1304 zed_version,
1305 );
1306 example.spec.rejected_patch = Some(rejected_patch);
1307 example
1308}
1309
/// Rejection details attached to an example built from a rejected prediction.
struct RejectionInfo {
    /// Rejection reason as reported by telemetry.
    reason: String,
    /// The `was_shown` flag reported by telemetry for the prediction.
    was_shown: bool,
}
1314
1315fn build_example_from_snowflake(
1316 request_id: String,
1317 device_id: String,
1318 time: String,
1319 input: ZetaPromptInput,
1320 tags: Vec<String>,
1321 rejection: Option<RejectionInfo>,
1322 zed_version: Option<String>,
1323) -> Example {
1324 let events: Vec<CapturedEvent> = input
1325 .events
1326 .iter()
1327 .map(|event| match event.as_ref() {
1328 zeta_prompt::Event::BufferChange {
1329 path,
1330 old_path,
1331 diff,
1332 predicted,
1333 in_open_source_repo,
1334 } => CapturedEvent {
1335 path: path.clone(),
1336 old_path: old_path.clone(),
1337 diff: diff.clone(),
1338 predicted: *predicted,
1339 in_open_source_repo: *in_open_source_repo,
1340 },
1341 })
1342 .collect();
1343
1344 let related_files: Vec<CapturedRelatedFile> = input
1345 .related_files
1346 .iter()
1347 .map(|rf| CapturedRelatedFile {
1348 path: rf.path.clone(),
1349 max_row: rf.max_row,
1350 excerpts: rf
1351 .excerpts
1352 .iter()
1353 .map(|e| CapturedRelatedExcerpt {
1354 row_range: e.row_range.clone(),
1355 text: e.text.to_string(),
1356 })
1357 .collect(),
1358 })
1359 .collect();
1360
1361 let cursor_excerpt = input.cursor_excerpt.as_ref();
1362 let cursor_offset = input.cursor_offset_in_excerpt;
1363
1364 let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1365
1366 let mut edit_history = String::new();
1367 for event in &input.events {
1368 zeta_prompt::write_event(&mut edit_history, event);
1369 edit_history.push('\n');
1370 }
1371
1372 let (rejection_reason, was_shown) = match &rejection {
1373 Some(r) => (r.reason.clone(), r.was_shown),
1374 None => (String::new(), false),
1375 };
1376
1377 let spec = ExampleSpec {
1378 name: request_id.clone(),
1379 repository_url: String::new(),
1380 revision: String::new(),
1381 tags,
1382 reasoning: None,
1383 uncommitted_diff: String::new(),
1384 cursor_path: input.cursor_path.clone(),
1385 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1386 edit_history,
1387 expected_patches: Vec::new(),
1388 rejected_patch: None,
1389 captured_prompt_input: Some(CapturedPromptInput {
1390 cursor_file_content: cursor_excerpt.to_string(),
1391 cursor_offset,
1392 cursor_row,
1393 cursor_column,
1394 excerpt_start_row: None,
1395 events,
1396 related_files,
1397 in_open_source_repo: input.in_open_source_repo,
1398 zed_version,
1399 }),
1400 telemetry: Some(TelemetrySource {
1401 request_id,
1402 device_id,
1403 time,
1404 rejection_reason,
1405 was_shown,
1406 }),
1407 human_feedback: Vec::new(),
1408 rating: None,
1409 };
1410
1411 Example {
1412 spec,
1413 prompt_inputs: None,
1414 prompt: None,
1415 predictions: Vec::new(),
1416 score: Vec::new(),
1417 qa: Vec::new(),
1418 state: None,
1419 }
1420}
1421
/// Converts a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// The row is the number of `'\n'` characters that start before `offset`. The
/// column is the distance in bytes from the start of that row to `offset`.
/// An `offset` past the end of `text` is not clamped: the column keeps growing
/// past the last line.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // In UTF-8 the byte 0x0A only ever encodes '\n', so scanning bytes is
    // equivalent to scanning characters here.
    let scanned = &text.as_bytes()[..offset.min(text.len())];

    let row = scanned.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = scanned
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline_at| newline_at + 1);

    (row, (offset - line_start) as u32)
}
1437
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. It is clamped to the excerpt length, and
/// if it lands inside a multi-byte UTF-8 character it is moved back to the
/// start of that character, so an offset from external data can never cause a
/// slicing panic.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
1443
1444fn build_output_patch(
1445 cursor_path: &std::path::Path,
1446 cursor_excerpt: &str,
1447 editable_range: &std::ops::Range<usize>,
1448 model_output: &str,
1449) -> String {
1450 let old_text = &cursor_excerpt[editable_range.clone()];
1451
1452 let editable_start_row = cursor_excerpt[..editable_range.start]
1453 .chars()
1454 .filter(|&c| c == '\n')
1455 .count() as u32;
1456
1457 let diff_body = language::unified_diff_with_offsets(
1458 old_text,
1459 model_output,
1460 editable_start_row,
1461 editable_start_row,
1462 );
1463
1464 let mut patch = String::new();
1465 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1466 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1467 patch.push_str(&diff_body);
1468 patch
1469}
1470
1471pub(crate) fn get_column_indices(
1472 meta: &Option<SnowflakeResultSetMetaData>,
1473 names: &[&str],
1474) -> std::collections::HashMap<String, usize> {
1475 let mut indices = std::collections::HashMap::new();
1476 if let Some(meta) = meta {
1477 for (index, col) in meta.row_type.iter().enumerate() {
1478 for &name in names {
1479 if col.name.eq_ignore_ascii_case(name) {
1480 indices.insert(name.to_string(), index);
1481 }
1482 }
1483 }
1484 }
1485 indices
1486}