1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11use telemetry_events::EditPredictionRating;
12
13use zeta_prompt::ZetaPromptInput;
14
15use crate::example::Example;
16use crate::progress::{InfoStyle, Progress, Step};
17use crate::sync_deployments::EDIT_PREDICTION_DEPLOYMENT_EVENT;
18use edit_prediction::example_spec::{
19 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
20 TelemetrySource,
21};
22use std::fmt::Write as _;
23
/// Response `code` that this module treats as success; the response parsers in this file
/// bail on any other code.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` reported while a statement is still executing; `run_sql_with_polling`
/// keeps polling for as long as it sees this code.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` bound into the query issued by `fetch_captured_examples_after`.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` of the request events (`req`) used by the rejected, requested and rated queries.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` of the rejection events (`rej`) joined in `fetch_rejected_examples_after`.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// `event_type` of the rating events (`rated`) queried by `fetch_rated_examples_after`.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
30
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
/// where `zed_version >= 0.224.1`.
///
/// The queries in this file compare `minor` with the second dot-separated component of
/// `zed_version` and `patch` with the third component (ignoring anything after a `+`).
#[derive(Clone, Copy, Debug)]
pub struct MinCaptureVersion {
    /// Minimum value of the second version component (`224` in `0.224.1`).
    pub minor: u32,
    /// Minimum value of the third version component (`1` in `0.224.1`); it is only compared
    /// when the minor component is exactly equal to `minor`.
    pub patch: u32,
}
39
/// Value sent in the `timeout` field of every statement request built in this file.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay between two consecutive polls of a statement that is still running.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls before `run_sql_with_polling` gives up on a running statement.
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
43
/// Extracts the timestamp from an input token of the form `captured-after:{timestamp}`.
///
/// Returns everything following the prefix (possibly an empty string), or `None` when
/// `input` does not start with `captured-after:`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    if input.starts_with(PREFIX) {
        Some(&input[PREFIX.len()..])
    } else {
        None
    }
}
48
/// Extracts the timestamp from an input token of the form `rejected-after:{timestamp}`.
///
/// Returns everything following the prefix (possibly an empty string), or `None` when
/// `input` does not start with `rejected-after:`.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    if input.starts_with(PREFIX) {
        Some(&input[PREFIX.len()..])
    } else {
        None
    }
}
53
/// Extracts the timestamp from an input token of the form `requested-after:{timestamp}`.
///
/// Returns everything following the prefix (possibly an empty string), or `None` when
/// `input` does not start with `requested-after:`.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    if input.starts_with(PREFIX) {
        Some(&input[PREFIX.len()..])
    } else {
        None
    }
}
58
59/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
60/// or `rated-negative-after:{timestamp}`.
61/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
62pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
63 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
64 Some((timestamp, Some(EditPredictionRating::Positive)))
65 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
66 Some((timestamp, Some(EditPredictionRating::Negative)))
67 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
68 Some((timestamp, None))
69 } else {
70 None
71 }
72}
73
74pub async fn fetch_captured_examples_after(
75 http_client: Arc<dyn HttpClient>,
76 after_timestamps: &[String],
77 max_rows_per_timestamp: usize,
78 offset: usize,
79 background_executor: BackgroundExecutor,
80 _min_capture_version: Option<MinCaptureVersion>,
81) -> Result<Vec<Example>> {
82 if after_timestamps.is_empty() {
83 return Ok(Vec::new());
84 }
85
86 let progress = Progress::global();
87
88 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
89 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
90 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
91 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
92 )?;
93 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
94
95 let mut all_examples = Vec::new();
96
97 for after_date in after_timestamps.iter() {
98 let step_progress_name = format!(">{after_date}");
99 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
100 step_progress.set_substatus("querying");
101
102 let statement = indoc! {r#"
103 SELECT
104 event_properties:example AS example
105 FROM events
106 WHERE event_type = ?
107 AND time > TRY_TO_TIMESTAMP_NTZ(?)
108 ORDER BY time ASC
109 LIMIT ?
110 OFFSET ?
111 "#};
112
113 let request = json!({
114 "statement": statement,
115 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
116 "database": "EVENTS",
117 "schema": "PUBLIC",
118 "warehouse": "DBT",
119 "role": role,
120 "bindings": {
121 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
122 "2": { "type": "TEXT", "value": after_date },
123 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
124 "4": { "type": "FIXED", "value": offset.to_string() }
125 }
126 });
127
128 let response = run_sql_with_polling(
129 http_client.clone(),
130 &base_url,
131 &token,
132 &request,
133 &step_progress,
134 background_executor.clone(),
135 )
136 .await?;
137
138 let total_rows = response
139 .result_set_meta_data
140 .as_ref()
141 .and_then(|m| m.num_rows)
142 .unwrap_or(response.data.len() as i64);
143
144 let num_partitions = response
145 .result_set_meta_data
146 .as_ref()
147 .map(|m| m.partition_info.len())
148 .unwrap_or(1)
149 .max(1);
150
151 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
152 step_progress.set_substatus("parsing");
153
154 let example_index = response
155 .result_set_meta_data
156 .as_ref()
157 .and_then(|m| {
158 m.row_type.iter().enumerate().find_map(|(index, col)| {
159 if col.name.eq_ignore_ascii_case("example") {
160 Some(index)
161 } else {
162 None
163 }
164 })
165 })
166 .unwrap_or(0);
167
168 all_examples.extend(examples_from_response(&response, example_index)?);
169
170 if num_partitions > 1 {
171 let statement_handle = response
172 .statement_handle
173 .as_ref()
174 .context("response has multiple partitions but no statementHandle")?;
175
176 for partition in 1..num_partitions {
177 step_progress.set_substatus(format!(
178 "fetching partition {}/{}",
179 partition + 1,
180 num_partitions
181 ));
182
183 let partition_response = fetch_partition(
184 http_client.clone(),
185 &base_url,
186 &token,
187 statement_handle,
188 partition,
189 )
190 .await?;
191
192 all_examples.extend(examples_from_response(&partition_response, example_index)?);
193 }
194 }
195
196 step_progress.set_substatus("done");
197 }
198
199 Ok(all_examples)
200}
201
/// Body of a statement response, as deserialized from JSON by `run_sql` and
/// `fetch_partition`.
///
/// Field names are camelCase on the wire, and every field falls back to its default when
/// absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON cells indexed by column position.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Result-set metadata (column names, row count, partitions), when present.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Message accompanying `code`; included in the error raised for non-success codes.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll the statement and to fetch its further partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
216
/// Metadata describing a statement's result set.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeResultSetMetaData {
    /// Column descriptions, in the same order as the cells of each data row.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported in the metadata, if any; only used for progress display here.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; callers in this file only use the number of entries.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
227
/// Placeholder for one `partitionInfo` entry. No fields are deserialized; the entries are
/// only counted to find out how many partitions a result has.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
231
/// Description of a single result column; only its name is deserialized.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name (empty when absent); the captured-examples fetcher matches it
    /// case-insensitively.
    #[serde(default)]
    name: String,
}
237
238fn examples_from_response(
239 response: &SnowflakeStatementResponse,
240 example_index: usize,
241) -> Result<impl Iterator<Item = Example> + '_> {
242 if let Some(code) = &response.code {
243 if code != SNOWFLAKE_SUCCESS_CODE {
244 anyhow::bail!(
245 "snowflake sql api returned error code={code} message={}",
246 response.message.as_deref().unwrap_or("<no message>")
247 );
248 }
249 }
250
251 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
252 let Some(example_value) = data_row.get(example_index) else {
253 return None;
254 };
255 if example_value.is_null() {
256 return None;
257 }
258
259 let parse_result = match example_value {
260 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
261 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
262 };
263
264 match parse_result {
265 Ok(spec) => Some(Example {
266 spec,
267 prompt_inputs: None,
268 prompt: None,
269 predictions: Vec::new(),
270 score: Vec::new(),
271 qa: Vec::new(),
272 state: None,
273 }),
274 Err(error) => {
275 let raw_json = serde_json::to_string_pretty(example_value)
276 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
277 log::error!(
278 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
279 );
280 None
281 }
282 }
283 });
284
285 Ok(iter)
286}
287
288async fn run_sql_with_polling(
289 http_client: Arc<dyn HttpClient>,
290 base_url: &str,
291 token: &str,
292 request: &serde_json::Value,
293 step_progress: &crate::progress::StepProgress,
294 background_executor: BackgroundExecutor,
295) -> Result<SnowflakeStatementResponse> {
296 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
297
298 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
299 let statement_handle = response
300 .statement_handle
301 .as_ref()
302 .context("async query response missing statementHandle")?
303 .clone();
304
305 for attempt in 1..=MAX_POLL_ATTEMPTS {
306 step_progress.set_substatus(format!("polling ({attempt})"));
307
308 background_executor.timer(POLL_INTERVAL).await;
309
310 response =
311 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
312
313 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
314 break;
315 }
316 }
317
318 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
319 anyhow::bail!(
320 "query still running after {} poll attempts ({} seconds)",
321 MAX_POLL_ATTEMPTS,
322 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
323 );
324 }
325 }
326
327 Ok(response)
328}
329
330pub(crate) async fn fetch_partition(
331 http_client: Arc<dyn HttpClient>,
332 base_url: &str,
333 token: &str,
334 statement_handle: &str,
335 partition: usize,
336) -> Result<SnowflakeStatementResponse> {
337 let url = format!(
338 "{}/api/v2/statements/{}?partition={}",
339 base_url.trim_end_matches('/'),
340 statement_handle,
341 partition
342 );
343
344 let http_request = Request::builder()
345 .method(Method::GET)
346 .uri(url.as_str())
347 .header("Authorization", format!("Bearer {token}"))
348 .header(
349 "X-Snowflake-Authorization-Token-Type",
350 "PROGRAMMATIC_ACCESS_TOKEN",
351 )
352 .header("Accept", "application/json")
353 .header("Accept-Encoding", "gzip")
354 .header("User-Agent", "edit_prediction_cli")
355 .body(AsyncBody::empty())?;
356
357 let response = http_client
358 .send(http_request)
359 .await
360 .context("failed to send partition request to Snowflake SQL API")?;
361
362 let status = response.status();
363 let content_encoding = response
364 .headers()
365 .get("content-encoding")
366 .and_then(|v| v.to_str().ok())
367 .map(|s| s.to_lowercase());
368
369 let body_bytes = {
370 use futures::AsyncReadExt as _;
371
372 let mut body = response.into_body();
373 let mut bytes = Vec::new();
374 body.read_to_end(&mut bytes)
375 .await
376 .context("failed to read Snowflake SQL API partition response body")?;
377 bytes
378 };
379
380 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
381 let mut decoder = GzDecoder::new(&body_bytes[..]);
382 let mut decompressed = Vec::new();
383 decoder
384 .read_to_end(&mut decompressed)
385 .context("failed to decompress gzip response")?;
386 decompressed
387 } else {
388 body_bytes
389 };
390
391 if !status.is_success() && status.as_u16() != 202 {
392 let body_text = String::from_utf8_lossy(&body_bytes);
393 anyhow::bail!(
394 "snowflake sql api partition request http {}: {}",
395 status.as_u16(),
396 body_text
397 );
398 }
399
400 if body_bytes.is_empty() {
401 anyhow::bail!(
402 "snowflake sql api partition {} returned empty response body (http {})",
403 partition,
404 status.as_u16()
405 );
406 }
407
408 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
409 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
410 format!(
411 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
412 partition,
413 status.as_u16(),
414 body_preview
415 )
416 })
417}
418
419pub(crate) async fn run_sql(
420 http_client: Arc<dyn HttpClient>,
421 base_url: &str,
422 token: &str,
423 request: &serde_json::Value,
424) -> Result<SnowflakeStatementResponse> {
425 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
426
427 let request_body =
428 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
429
430 let http_request = Request::builder()
431 .method(Method::POST)
432 .uri(url.as_str())
433 .header("Authorization", format!("Bearer {token}"))
434 .header(
435 "X-Snowflake-Authorization-Token-Type",
436 "PROGRAMMATIC_ACCESS_TOKEN",
437 )
438 .header("Content-Type", "application/json")
439 .header("Accept", "application/json")
440 .header("User-Agent", "edit_prediction_cli")
441 .body(AsyncBody::from(request_body.clone()))?;
442
443 let response = http_client
444 .send(http_request)
445 .await
446 .context("failed to send request to Snowflake SQL API")?;
447
448 let status = response.status();
449 let body_bytes = {
450 use futures::AsyncReadExt as _;
451
452 let mut body = response.into_body();
453 let mut bytes = Vec::new();
454 body.read_to_end(&mut bytes)
455 .await
456 .context("failed to read Snowflake SQL API response body")?;
457 bytes
458 };
459
460 if !status.is_success() && status.as_u16() != 202 {
461 let body_text = String::from_utf8_lossy(&body_bytes);
462 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
463 }
464
465 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
466 .context("failed to parse Snowflake SQL API response JSON")
467}
468
/// Fetches examples built from predictions that were shown to the user and then rejected.
///
/// One query is issued per entry of `after_timestamps`. It joins "Predictive Edit Requested"
/// events with "Predictive Edit Rejected" events on `request_id`, keeps `V3` requests newer
/// than the timestamp whose rejection has `was_shown = true`, and optionally applies
/// `min_capture_version` to the request's `zed_version`. Rows are ordered by request time
/// and paged with `max_rows_per_timestamp` / `offset`. Rows of every partition are converted
/// by `rejected_examples_from_response` (defined elsewhere in this file).
///
/// Returns an empty list, without reading the environment, when `after_timestamps` is empty.
///
/// # Errors
///
/// Fails when `EP_SNOWFLAKE_API_KEY` or `EP_SNOWFLAKE_BASE_URL` is not set, when a request
/// or poll fails, when converting a response fails, or when a multi-partition result comes
/// without a statement handle.
pub async fn fetch_rejected_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    // The role is optional; when unset it is serialized as JSON `null`.
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("rejected>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // Join rejected events with their corresponding request events to get the full context.
        // We filter for V3 sampling data which contains the structured input we need.
        // We also filter for predictions that were actually shown to the user (was_shown = true)
        // to focus on explicit user rejections rather than implicit cancellations.
        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:prompt::string AS prompt,
                req.event_properties:output::string AS output,
                rej.event_properties:was_shown::boolean AS was_shown,
                rej.event_properties:reason::string AS reason,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            INNER JOIN events rej
                ON req.event_properties:request_id = rej.event_properties:request_id
            WHERE req.event_type = ?
                AND rej.event_type = ?
                AND req.event_properties:version = 'V3'
                AND rej.event_properties:was_shown = true
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Without a minimum version both values are `None` and are bound as JSON `null`,
        // which makes the `? IS NULL` guard in the statement skip the version filter.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        // Bindings are positional and follow the order of the `?` placeholders:
        // 1 request event type, 2 rejection event type, 3 timestamp, 4 `IS NULL` guard,
        // 5 minor `>`, 6 minor `=`, 7 patch `>=`, 8 LIMIT, 9 OFFSET.
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
                "3": { "type": "TEXT", "value": after_date },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_minor_str_ref },
                "7": { "type": "FIXED", "value": min_patch_str_ref },
                "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "9": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Fall back to the number of rows in this response when the metadata has no count.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // A response without partition information counts as a single partition.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Column positions come from the first response and are reused for every partition.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "device_id",
                "time",
                "input",
                "prompt",
                "output",
                "was_shown",
                "reason",
                "zed_version",
            ],
        );

        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            // Partition 0 arrived with the first response; fetch the rest in order.
            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rejected_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
633
/// Fetches examples built from prediction requests, regardless of their outcome.
///
/// One query is issued per entry of `after_timestamps`. It selects `V3` "Predictive Edit
/// Requested" events newer than the timestamp, optionally applies `min_capture_version` to
/// `zed_version`, orders rows by time and pages them with `max_rows_per_timestamp` /
/// `offset`. Rows of every partition are converted by `requested_examples_from_response`
/// (defined elsewhere in this file).
///
/// Returns an empty list, without reading the environment, when `after_timestamps` is empty.
///
/// # Errors
///
/// Fails when `EP_SNOWFLAKE_API_KEY` or `EP_SNOWFLAKE_BASE_URL` is not set, when a request
/// or poll fails, when converting a response fails, or when a multi-partition result comes
/// without a statement handle.
pub async fn fetch_requested_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    // The role is optional; when unset it is serialized as JSON `null`.
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("requested>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            WHERE req.event_type = ?
                AND req.event_properties:version = 'V3'
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Without a minimum version both values are `None` and are bound as JSON `null`,
        // which makes the `? IS NULL` guard in the statement skip the version filter.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        // Bindings are positional and follow the order of the `?` placeholders:
        // 1 request event type, 2 timestamp, 3 `IS NULL` guard, 4 minor `>`, 5 minor `=`,
        // 6 patch `>=`, 7 LIMIT, 8 OFFSET.
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": after_date },
                "3": { "type": "FIXED", "value": min_minor_str_ref },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_patch_str_ref },
                "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "8": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Fall back to the number of rows in this response when the metadata has no count.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // A response without partition information counts as a single partition.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Column positions come from the first response and are reused for every partition.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &["request_id", "device_id", "time", "input", "zed_version"],
        );

        all_examples.extend(requested_examples_from_response(
            &response,
            &column_indices,
        )?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            // Partition 0 arrived with the first response; fetch the rest in order.
            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(requested_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
778
/// Fetches examples built from predictions that users rated.
///
/// `inputs` holds `(timestamp, rating filter)` pairs; a `None` filter selects all ratings.
/// One query is issued per pair. It selects "Edit Prediction Rated" events newer than the
/// timestamp that have `inputs`, `inputs:cursor_excerpt` and `output`, optionally restricted
/// to one rating and to `min_capture_version`. Request and deployment events are
/// left-joined to supply `experiment_name` and `environment` where available. Rows are
/// ordered by rating time, paged with `max_rows_per_timestamp` / `offset`, and converted by
/// `rated_examples_from_response`.
///
/// Returns an empty list, without reading the environment, when `inputs` is empty.
///
/// # Errors
///
/// Fails when `EP_SNOWFLAKE_API_KEY` or `EP_SNOWFLAKE_BASE_URL` is not set, when a request
/// or poll fails, when a response carries a non-success code, or when a multi-partition
/// result comes without a statement handle.
pub async fn fetch_rated_examples_after(
    http_client: Arc<dyn HttpClient>,
    inputs: &[(String, Option<EditPredictionRating>)],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if inputs.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    // The role is optional; when unset it is serialized as JSON `null`.
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    for (after_date, rating_filter) in inputs.iter() {
        // Suffix shown in the progress step name to tell the rating filters apart.
        let filter_label = match rating_filter {
            None => "",
            Some(EditPredictionRating::Positive) => ":positive",
            Some(EditPredictionRating::Negative) => ":negative",
        };
        let step_progress_name = format!("rated{filter_label}>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // Text compared with the stored `rating` property; `None` is bound as JSON `null`,
        // which makes the `? IS NULL` guard in the statement accept every rating.
        let rating_value = rating_filter.as_ref().map(|r| match r {
            EditPredictionRating::Positive => "Positive",
            EditPredictionRating::Negative => "Negative",
        });

        let statement = indoc! {r#"
            SELECT
                rated.event_properties:request_id::string AS request_id,
                rated.event_properties:inputs AS inputs,
                rated.event_properties:output::string AS output,
                rated.event_properties:rating::string AS rating,
                rated.event_properties:feedback::string AS feedback,
                rated.device_id::string AS device_id,
                rated.time::string AS time,
                deploy.event_properties:experiment_name::string AS experiment_name,
                deploy.event_properties:environment::string AS environment,
                rated.event_properties:zed_version::string AS zed_version
            FROM events rated
            LEFT JOIN events req
                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
                AND req.event_type = ?
            LEFT JOIN events deploy
                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
                AND deploy.event_type = ?
            WHERE rated.event_type = ?
                AND (? IS NULL OR rated.event_properties:rating::string = ?)
                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND rated.event_properties:inputs IS NOT NULL
                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
                AND rated.event_properties:output IS NOT NULL
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(rated.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY rated.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Without a minimum version both values are `None` and are bound as JSON `null`,
        // which makes the second `? IS NULL` guard skip the version filter.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        // Bindings are positional and follow the order of the `?` placeholders:
        // 1 request event type, 2 deployment event type, 3 rated event type,
        // 4 rating `IS NULL` guard, 5 rating value, 6 timestamp, 7 version `IS NULL` guard,
        // 8 minor `>`, 9 minor `=`, 10 patch `>=`, 11 LIMIT, 12 OFFSET.
        let bindings = json!({
            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
            "4": { "type": "TEXT", "value": rating_value },
            "5": { "type": "TEXT", "value": rating_value },
            "6": { "type": "TEXT", "value": after_date },
            "7": { "type": "FIXED", "value": min_minor_str_ref },
            "8": { "type": "FIXED", "value": min_minor_str_ref },
            "9": { "type": "FIXED", "value": min_minor_str_ref },
            "10": { "type": "FIXED", "value": min_patch_str_ref },
            "11": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
            "12": { "type": "FIXED", "value": offset.to_string() }
        });

        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": bindings
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Fall back to the number of rows in this response when the metadata has no count.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // A response without partition information counts as a single partition.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Column positions come from the first response and are reused for every partition.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "inputs",
                "output",
                "rating",
                "feedback",
                "device_id",
                "time",
                "experiment_name",
                "environment",
                "zed_version",
            ],
        );

        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            // Partition 0 arrived with the first response; fetch the rest in order.
            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rated_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
962
963fn rated_examples_from_response<'a>(
964 response: &'a SnowflakeStatementResponse,
965 column_indices: &'a std::collections::HashMap<String, usize>,
966) -> Result<impl Iterator<Item = Example> + 'a> {
967 if let Some(code) = &response.code {
968 if code != SNOWFLAKE_SUCCESS_CODE {
969 anyhow::bail!(
970 "snowflake sql api returned error code={code} message={}",
971 response.message.as_deref().unwrap_or("<no message>")
972 );
973 }
974 }
975
976 let iter = response
977 .data
978 .iter()
979 .enumerate()
980 .filter_map(move |(row_index, data_row)| {
981 let get_string = |name: &str| -> Option<String> {
982 let index = column_indices.get(name).copied()?;
983 match data_row.get(index)? {
984 JsonValue::String(s) => Some(s.clone()),
985 JsonValue::Null => None,
986 other => Some(other.to_string()),
987 }
988 };
989
990 let get_json = |name: &str| -> Option<JsonValue> {
991 let index = column_indices.get(name).copied()?;
992 let value = data_row.get(index)?;
993 if value.is_null() {
994 return None;
995 }
996 match value {
997 JsonValue::String(s) => serde_json::from_str(s).ok(),
998 other => Some(other.clone()),
999 }
1000 };
1001
1002 let request_id = get_string("request_id");
1003 let inputs_json = get_json("inputs");
1004 let inputs: Option<ZetaPromptInput> = match &inputs_json {
1005 Some(v) => match serde_json::from_value(v.clone()) {
1006 Ok(parsed) => Some(parsed),
1007 Err(e) => {
1008 log::warn!(
1009 "skipping row {row_index}: failed to parse inputs - {e}",
1010 );
1011 return None;
1012 }
1013 },
1014 None => None,
1015 };
1016 let output = get_string("output");
1017 let rating = get_string("rating");
1018 let feedback = get_string("feedback").unwrap_or_default();
1019 let device_id = get_string("device_id");
1020 let time = get_string("time");
1021 let experiment_name = get_string("experiment_name");
1022 let environment = get_string("environment");
1023 let zed_version = get_string("zed_version");
1024
1025 match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
1026 (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
1027 Some(build_rated_example(
1028 request_id,
1029 device_id,
1030 time,
1031 inputs,
1032 output,
1033 rating,
1034 feedback,
1035 experiment_name,
1036 environment,
1037 zed_version,
1038 ))
1039 }
1040 _ => {
1041 log::warn!(
1042 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
1043 inputs_json.is_some(),
1044 output.is_some(),
1045 rating.is_some(),
1046 device_id.is_some(),
1047 time.is_some(),
1048 );
1049 None
1050 }
1051 }
1052 });
1053
1054 Ok(iter)
1055}
1056
1057fn build_rated_example(
1058 request_id: Option<String>,
1059 device_id: String,
1060 time: String,
1061 input: ZetaPromptInput,
1062 output: String,
1063 rating: String,
1064 feedback: String,
1065 experiment_name: Option<String>,
1066 environment: Option<String>,
1067 zed_version: Option<String>,
1068) -> Example {
1069 let parsed_rating = if rating == "Positive" {
1070 EditPredictionRating::Positive
1071 } else {
1072 EditPredictionRating::Negative
1073 };
1074 let is_positive = parsed_rating == EditPredictionRating::Positive;
1075 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1076
1077 let mut tags = Vec::with_capacity(3);
1078 tags.push(if is_positive {
1079 "rated:positive".to_string()
1080 } else {
1081 "rated:negative".to_string()
1082 });
1083 if let Some(experiment) = experiment_name {
1084 tags.push(format!("experiment:{experiment}"));
1085 }
1086 if let Some(env) = environment {
1087 tags.push(format!("environment:{env}"));
1088 }
1089
1090 let mut example =
1091 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1092
1093 example.spec.rating = Some(parsed_rating);
1094
1095 if !feedback.is_empty() {
1096 example
1097 .spec
1098 .human_feedback
1099 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1100 }
1101
1102 if is_positive {
1103 example.spec.expected_patches = vec![output];
1104 } else {
1105 example.spec.rejected_patch = Some(output);
1106 }
1107
1108 example
1109}
1110
1111fn requested_examples_from_response<'a>(
1112 response: &'a SnowflakeStatementResponse,
1113 column_indices: &'a std::collections::HashMap<String, usize>,
1114) -> Result<impl Iterator<Item = Example> + 'a> {
1115 if let Some(code) = &response.code {
1116 if code != SNOWFLAKE_SUCCESS_CODE {
1117 anyhow::bail!(
1118 "snowflake sql api returned error code={code} message={}",
1119 response.message.as_deref().unwrap_or("<no message>")
1120 );
1121 }
1122 }
1123
1124 let iter = response
1125 .data
1126 .iter()
1127 .enumerate()
1128 .filter_map(move |(row_index, data_row)| {
1129 let get_string = |name: &str| -> Option<String> {
1130 let index = column_indices.get(name).copied()?;
1131 match data_row.get(index)? {
1132 JsonValue::String(s) => Some(s.clone()),
1133 JsonValue::Null => None,
1134 other => Some(other.to_string()),
1135 }
1136 };
1137
1138 let get_json = |name: &str| -> Option<JsonValue> {
1139 let index = column_indices.get(name).copied()?;
1140 let value = data_row.get(index)?;
1141 if value.is_null() {
1142 return None;
1143 }
1144 match value {
1145 JsonValue::String(s) => serde_json::from_str(s).ok(),
1146 other => Some(other.clone()),
1147 }
1148 };
1149
1150 let request_id_str = get_string("request_id");
1151 let device_id = get_string("device_id");
1152 let time = get_string("time");
1153 let input_json = get_json("input");
1154 let input: Option<ZetaPromptInput> =
1155 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1156 let zed_version = get_string("zed_version");
1157
1158 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1159 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1160 Some(build_example_from_snowflake(
1161 request_id,
1162 device_id,
1163 time,
1164 input,
1165 vec!["requested".to_string()],
1166 None,
1167 zed_version,
1168 ))
1169 }
1170 _ => {
1171 log::warn!(
1172 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1173 request_id_str.is_some(),
1174 device_id.is_some(),
1175 time.is_some(),
1176 input_json.is_some(),
1177 );
1178 None
1179 }
1180 }
1181 });
1182
1183 Ok(iter)
1184}
1185
1186fn rejected_examples_from_response<'a>(
1187 response: &'a SnowflakeStatementResponse,
1188 column_indices: &'a std::collections::HashMap<String, usize>,
1189) -> Result<impl Iterator<Item = Example> + 'a> {
1190 if let Some(code) = &response.code {
1191 if code != SNOWFLAKE_SUCCESS_CODE {
1192 anyhow::bail!(
1193 "snowflake sql api returned error code={code} message={}",
1194 response.message.as_deref().unwrap_or("<no message>")
1195 );
1196 }
1197 }
1198
1199 let iter = response
1200 .data
1201 .iter()
1202 .enumerate()
1203 .filter_map(move |(row_index, data_row)| {
1204 let get_string = |name: &str| -> Option<String> {
1205 let index = column_indices.get(name).copied()?;
1206 match data_row.get(index)? {
1207 JsonValue::String(s) => Some(s.clone()),
1208 JsonValue::Null => None,
1209 other => Some(other.to_string()),
1210 }
1211 };
1212
1213 let get_json = |name: &str| -> Option<JsonValue> {
1214 let index = column_indices.get(name).copied()?;
1215 let value = data_row.get(index)?;
1216 if value.is_null() {
1217 return None;
1218 }
1219 match value {
1220 JsonValue::String(s) => serde_json::from_str(s).ok(),
1221 other => Some(other.clone()),
1222 }
1223 };
1224
1225 let get_bool = |name: &str| -> Option<bool> {
1226 let index = column_indices.get(name).copied()?;
1227 match data_row.get(index)? {
1228 JsonValue::Bool(b) => Some(*b),
1229 JsonValue::String(s) => s.parse().ok(),
1230 _ => None,
1231 }
1232 };
1233
1234 let request_id_str = get_string("request_id");
1235 let device_id = get_string("device_id");
1236 let time = get_string("time");
1237 let input_json = get_json("input");
1238 let input: Option<ZetaPromptInput> =
1239 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1240 let output = get_string("output");
1241 let was_shown = get_bool("was_shown");
1242 let reason = get_string("reason");
1243 let zed_version = get_string("zed_version");
1244
1245 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1246 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1247 Some(build_rejected_example(
1248 request_id,
1249 device_id,
1250 time,
1251 input,
1252 output,
1253 was_shown,
1254 reason,
1255 zed_version,
1256 ))
1257 }
1258 _ => {
1259 log::warn!(
1260 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1261 request_id_str.is_some(),
1262 device_id.is_some(),
1263 time.is_some(),
1264 input_json.is_some(),
1265 output.is_some(),
1266 was_shown.is_some(),
1267 reason.is_some()
1268 );
1269 None
1270 }
1271 }
1272 });
1273
1274 Ok(iter)
1275}
1276
1277fn build_rejected_example(
1278 request_id: String,
1279 device_id: String,
1280 time: String,
1281 input: ZetaPromptInput,
1282 output: String,
1283 was_shown: bool,
1284 reason: String,
1285 zed_version: Option<String>,
1286) -> Example {
1287 let rejected_patch = build_output_patch(
1288 &input.cursor_path,
1289 input.cursor_excerpt.as_ref(),
1290 &input.editable_range_in_excerpt,
1291 &output,
1292 );
1293 let mut example = build_example_from_snowflake(
1294 request_id,
1295 device_id,
1296 time,
1297 input,
1298 vec![format!("rejection:{}", reason.to_lowercase())],
1299 Some(RejectionInfo { reason, was_shown }),
1300 zed_version,
1301 );
1302 example.spec.rejected_patch = Some(rejected_patch);
1303 example
1304}
1305
/// Rejection details attached to an example built from telemetry.
///
/// `build_example_from_snowflake` copies both fields into the example's
/// telemetry record.
struct RejectionInfo {
    /// Presumably whether the prediction had been displayed to the user
    /// before it was rejected (inferred from the name; confirm against the
    /// telemetry schema).
    was_shown: bool,
    /// Reason reported for the rejection.
    reason: String,
}
1310
1311fn build_example_from_snowflake(
1312 request_id: String,
1313 device_id: String,
1314 time: String,
1315 input: ZetaPromptInput,
1316 tags: Vec<String>,
1317 rejection: Option<RejectionInfo>,
1318 zed_version: Option<String>,
1319) -> Example {
1320 let events: Vec<CapturedEvent> = input
1321 .events
1322 .iter()
1323 .map(|event| match event.as_ref() {
1324 zeta_prompt::Event::BufferChange {
1325 path,
1326 old_path,
1327 diff,
1328 predicted,
1329 in_open_source_repo,
1330 } => CapturedEvent {
1331 path: path.clone(),
1332 old_path: old_path.clone(),
1333 diff: diff.clone(),
1334 predicted: *predicted,
1335 in_open_source_repo: *in_open_source_repo,
1336 },
1337 })
1338 .collect();
1339
1340 let related_files: Vec<CapturedRelatedFile> = input
1341 .related_files
1342 .iter()
1343 .map(|rf| CapturedRelatedFile {
1344 path: rf.path.clone(),
1345 max_row: rf.max_row,
1346 excerpts: rf
1347 .excerpts
1348 .iter()
1349 .map(|e| CapturedRelatedExcerpt {
1350 row_range: e.row_range.clone(),
1351 text: e.text.to_string(),
1352 })
1353 .collect(),
1354 })
1355 .collect();
1356
1357 let cursor_excerpt = input.cursor_excerpt.as_ref();
1358 let cursor_offset = input.cursor_offset_in_excerpt;
1359
1360 let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1361
1362 let mut edit_history = String::new();
1363 for event in &input.events {
1364 zeta_prompt::write_event(&mut edit_history, event);
1365 edit_history.push('\n');
1366 }
1367
1368 let (rejection_reason, was_shown) = match &rejection {
1369 Some(r) => (r.reason.clone(), r.was_shown),
1370 None => (String::new(), false),
1371 };
1372
1373 let spec = ExampleSpec {
1374 name: request_id.clone(),
1375 repository_url: String::new(),
1376 revision: String::new(),
1377 tags,
1378 reasoning: None,
1379 uncommitted_diff: String::new(),
1380 cursor_path: input.cursor_path.clone(),
1381 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1382 edit_history,
1383 expected_patches: Vec::new(),
1384 rejected_patch: None,
1385 captured_prompt_input: Some(CapturedPromptInput {
1386 cursor_file_content: cursor_excerpt.to_string(),
1387 cursor_offset,
1388 cursor_row,
1389 cursor_column,
1390 excerpt_start_row: None,
1391 events,
1392 related_files,
1393 in_open_source_repo: input.in_open_source_repo,
1394 zed_version,
1395 }),
1396 telemetry: Some(TelemetrySource {
1397 request_id,
1398 device_id,
1399 time,
1400 rejection_reason,
1401 was_shown,
1402 }),
1403 human_feedback: Vec::new(),
1404 rating: None,
1405 };
1406
1407 Example {
1408 spec,
1409 prompt_inputs: None,
1410 prompt: None,
1411 predictions: Vec::new(),
1412 score: Vec::new(),
1413 qa: Vec::new(),
1414 state: None,
1415 }
1416}
1417
/// Converts a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// `row` is the number of `\n` characters that start before `offset`.
/// `column` is the distance in bytes from the character following the last
/// such newline (or from the start of `text`) to `offset`. An `offset` beyond
/// the end of `text` is not clamped: the column keeps growing with it.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // Only bytes strictly before `offset` take part. Scanning raw bytes is
    // safe because `\n` never occurs inside a multi-byte UTF-8 sequence.
    let scanned = &text.as_bytes()[..offset.min(text.len())];

    let row = scanned.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = scanned
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline_at| newline_at + 1);

    (row, (offset - line_start) as u32)
}
1433
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at
/// `cursor_offset`.
///
/// `cursor_offset` is a byte offset into `excerpt`. An offset past the end
/// places the marker at the end. An offset that falls inside a multi-byte
/// character is moved back to the start of that character, so this function
/// never panics on an offset that is not a char boundary.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
1439
1440fn build_output_patch(
1441 cursor_path: &std::path::Path,
1442 cursor_excerpt: &str,
1443 editable_range: &std::ops::Range<usize>,
1444 model_output: &str,
1445) -> String {
1446 let old_text = &cursor_excerpt[editable_range.clone()];
1447
1448 let editable_start_row = cursor_excerpt[..editable_range.start]
1449 .chars()
1450 .filter(|&c| c == '\n')
1451 .count() as u32;
1452
1453 let diff_body = language::unified_diff_with_offsets(
1454 old_text,
1455 model_output,
1456 editable_start_row,
1457 editable_start_row,
1458 );
1459
1460 let mut patch = String::new();
1461 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1462 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1463 patch.push_str(&diff_body);
1464 patch
1465}
1466
1467pub(crate) fn get_column_indices(
1468 meta: &Option<SnowflakeResultSetMetaData>,
1469 names: &[&str],
1470) -> std::collections::HashMap<String, usize> {
1471 let mut indices = std::collections::HashMap::new();
1472 if let Some(meta) = meta {
1473 for (index, col) in meta.row_type.iter().enumerate() {
1474 for &name in names {
1475 if col.name.eq_ignore_ascii_case(name) {
1476 indices.insert(name.to_string(), index);
1477 }
1478 }
1479 }
1480 }
1481 indices
1482}