1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11
12use zeta_prompt::ZetaPromptInput;
13
14use crate::example::Example;
15use crate::progress::{InfoStyle, Progress, Step};
16use edit_prediction::example_spec::{
17 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
18 TelemetrySource,
19};
20use std::fmt::Write as _;
21
/// Response `code` treated as success; the `*_from_response` helpers bail on any other code.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` that makes `run_sql_with_polling` keep polling the statement.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` bound into the captured-examples query.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` of the request side in the rejected/requested queries.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` of the rejection side joined in the rejected-examples query.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";

/// Value sent as the `timeout` field of every statement request.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each poll of a still-running statement.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Upper bound on polls before `run_sql_with_polling` gives up with an error.
const MAX_POLL_ATTEMPTS: usize = 120;
31
/// Extracts the timestamp from a `captured-after:{timestamp}` token.
///
/// Returns the text following the prefix, or `None` when `input` does not
/// start with `captured-after:`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input.strip_prefix(PREFIX)
}
36
/// Extracts the timestamp from a `rejected-after:{timestamp}` token.
///
/// Returns the text following the prefix, or `None` when `input` does not
/// start with `rejected-after:`.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    input.strip_prefix(PREFIX)
}
41
/// Extracts the timestamp from a `requested-after:{timestamp}` token.
///
/// Returns the text following the prefix, or `None` when `input` does not
/// start with `requested-after:`.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    input.strip_prefix(PREFIX)
}
46
47pub async fn fetch_captured_examples_after(
48 http_client: Arc<dyn HttpClient>,
49 after_timestamps: &[String],
50 max_rows_per_timestamp: usize,
51 background_executor: BackgroundExecutor,
52) -> Result<Vec<Example>> {
53 if after_timestamps.is_empty() {
54 return Ok(Vec::new());
55 }
56
57 let progress = Progress::global();
58
59 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
60 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
61 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
62 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
63 )?;
64 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
65
66 let mut all_examples = Vec::new();
67
68 for after_date in after_timestamps.iter() {
69 let step_progress_name = format!(">{after_date}");
70 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
71 step_progress.set_substatus("querying");
72
73 let statement = indoc! {r#"
74 SELECT
75 event_properties:example AS example
76 FROM events
77 WHERE event_type = ?
78 AND time > TRY_TO_TIMESTAMP_NTZ(?)
79 ORDER BY time ASC
80 LIMIT ?
81 "#};
82
83 let request = json!({
84 "statement": statement,
85 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
86 "database": "EVENTS",
87 "schema": "PUBLIC",
88 "warehouse": "DBT",
89 "role": role,
90 "bindings": {
91 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
92 "2": { "type": "TEXT", "value": after_date },
93 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
94 }
95 });
96
97 let response = run_sql_with_polling(
98 http_client.clone(),
99 &base_url,
100 &token,
101 &request,
102 &step_progress,
103 background_executor.clone(),
104 )
105 .await?;
106
107 let total_rows = response
108 .result_set_meta_data
109 .as_ref()
110 .and_then(|m| m.num_rows)
111 .unwrap_or(response.data.len() as i64);
112
113 let num_partitions = response
114 .result_set_meta_data
115 .as_ref()
116 .map(|m| m.partition_info.len())
117 .unwrap_or(1)
118 .max(1);
119
120 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
121 step_progress.set_substatus("parsing");
122
123 all_examples.extend(examples_from_response(&response)?);
124
125 if num_partitions > 1 {
126 let statement_handle = response
127 .statement_handle
128 .as_ref()
129 .context("response has multiple partitions but no statementHandle")?;
130
131 for partition in 1..num_partitions {
132 step_progress.set_substatus(format!(
133 "fetching partition {}/{}",
134 partition + 1,
135 num_partitions
136 ));
137
138 let partition_response = fetch_partition(
139 http_client.clone(),
140 &base_url,
141 &token,
142 statement_handle,
143 partition,
144 )
145 .await?;
146
147 all_examples.extend(examples_from_response(&partition_response)?);
148 }
149 }
150
151 step_progress.set_substatus("done");
152 }
153
154 Ok(all_examples)
155}
156
/// Body of a Snowflake SQL API statement response, as deserialized by
/// `run_sql` and `fetch_partition`. Every field is optional in the JSON and
/// falls back to its default when absent.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON cell values.
    #[serde(default)]
    data: Vec<Vec<JsonValue>>,
    /// Column, row-count and partition metadata, when the response has it.
    #[serde(default)]
    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against the success / in-progress constants.
    #[serde(default)]
    code: Option<String>,
    /// Message accompanying `code`, surfaced in error reports.
    #[serde(default)]
    message: Option<String>,
    /// Handle used to poll the statement and to fetch further partitions.
    #[serde(default)]
    statement_handle: Option<String>,
}
171
/// The `resultSetMetaData` object of a statement response.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeResultSetMetaData {
    /// Column descriptors, in the same order as the cells of each data row.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported by the server, used for progress display.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
182
/// Placeholder for one `partitionInfo` entry. No fields are read; callers
/// only count the entries to learn how many partitions to fetch.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
186
/// Descriptor of a single result column (one `rowType` entry).
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; matched case-insensitively when locating columns.
    #[serde(default)]
    name: String,
}
192
193fn examples_from_response(
194 response: &SnowflakeStatementResponse,
195) -> Result<impl Iterator<Item = Example> + '_> {
196 if let Some(code) = &response.code {
197 if code != SNOWFLAKE_SUCCESS_CODE {
198 anyhow::bail!(
199 "snowflake sql api returned error code={code} message={}",
200 response.message.as_deref().unwrap_or("<no message>")
201 );
202 }
203 }
204
205 let example_index = response
206 .result_set_meta_data
207 .as_ref()
208 .and_then(|m| {
209 m.row_type.iter().enumerate().find_map(|(index, col)| {
210 if col.name.eq_ignore_ascii_case("example") {
211 Some(index)
212 } else {
213 None
214 }
215 })
216 })
217 .unwrap_or(0);
218
219 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
220 let Some(example_value) = data_row.get(example_index) else {
221 return None;
222 };
223 if example_value.is_null() {
224 return None;
225 }
226
227 let parse_result = match example_value {
228 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
229 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
230 };
231
232 match parse_result {
233 Ok(spec) => Some(Example {
234 spec,
235 prompt_inputs: None,
236 prompt: None,
237 predictions: Vec::new(),
238 score: Vec::new(),
239 qa: Vec::new(),
240 state: None,
241 }),
242 Err(error) => {
243 let raw_json = serde_json::to_string_pretty(example_value)
244 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
245 log::error!(
246 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
247 );
248 None
249 }
250 }
251 });
252
253 Ok(iter)
254}
255
256async fn run_sql_with_polling(
257 http_client: Arc<dyn HttpClient>,
258 base_url: &str,
259 token: &str,
260 request: &serde_json::Value,
261 step_progress: &crate::progress::StepProgress,
262 background_executor: BackgroundExecutor,
263) -> Result<SnowflakeStatementResponse> {
264 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
265
266 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
267 let statement_handle = response
268 .statement_handle
269 .as_ref()
270 .context("async query response missing statementHandle")?
271 .clone();
272
273 for attempt in 1..=MAX_POLL_ATTEMPTS {
274 step_progress.set_substatus(format!("polling ({attempt})"));
275
276 background_executor.timer(POLL_INTERVAL).await;
277
278 response =
279 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
280
281 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
282 break;
283 }
284 }
285
286 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
287 anyhow::bail!(
288 "query still running after {} poll attempts ({} seconds)",
289 MAX_POLL_ATTEMPTS,
290 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
291 );
292 }
293 }
294
295 Ok(response)
296}
297
298async fn fetch_partition(
299 http_client: Arc<dyn HttpClient>,
300 base_url: &str,
301 token: &str,
302 statement_handle: &str,
303 partition: usize,
304) -> Result<SnowflakeStatementResponse> {
305 let url = format!(
306 "{}/api/v2/statements/{}?partition={}",
307 base_url.trim_end_matches('/'),
308 statement_handle,
309 partition
310 );
311
312 let http_request = Request::builder()
313 .method(Method::GET)
314 .uri(url.as_str())
315 .header("Authorization", format!("Bearer {token}"))
316 .header(
317 "X-Snowflake-Authorization-Token-Type",
318 "PROGRAMMATIC_ACCESS_TOKEN",
319 )
320 .header("Accept", "application/json")
321 .header("Accept-Encoding", "gzip")
322 .header("User-Agent", "edit_prediction_cli")
323 .body(AsyncBody::empty())?;
324
325 let response = http_client
326 .send(http_request)
327 .await
328 .context("failed to send partition request to Snowflake SQL API")?;
329
330 let status = response.status();
331 let content_encoding = response
332 .headers()
333 .get("content-encoding")
334 .and_then(|v| v.to_str().ok())
335 .map(|s| s.to_lowercase());
336
337 let body_bytes = {
338 use futures::AsyncReadExt as _;
339
340 let mut body = response.into_body();
341 let mut bytes = Vec::new();
342 body.read_to_end(&mut bytes)
343 .await
344 .context("failed to read Snowflake SQL API partition response body")?;
345 bytes
346 };
347
348 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
349 let mut decoder = GzDecoder::new(&body_bytes[..]);
350 let mut decompressed = Vec::new();
351 decoder
352 .read_to_end(&mut decompressed)
353 .context("failed to decompress gzip response")?;
354 decompressed
355 } else {
356 body_bytes
357 };
358
359 if !status.is_success() && status.as_u16() != 202 {
360 let body_text = String::from_utf8_lossy(&body_bytes);
361 anyhow::bail!(
362 "snowflake sql api partition request http {}: {}",
363 status.as_u16(),
364 body_text
365 );
366 }
367
368 if body_bytes.is_empty() {
369 anyhow::bail!(
370 "snowflake sql api partition {} returned empty response body (http {})",
371 partition,
372 status.as_u16()
373 );
374 }
375
376 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
377 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
378 format!(
379 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
380 partition,
381 status.as_u16(),
382 body_preview
383 )
384 })
385}
386
387async fn run_sql(
388 http_client: Arc<dyn HttpClient>,
389 base_url: &str,
390 token: &str,
391 request: &serde_json::Value,
392) -> Result<SnowflakeStatementResponse> {
393 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
394
395 let request_body =
396 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
397
398 let http_request = Request::builder()
399 .method(Method::POST)
400 .uri(url.as_str())
401 .header("Authorization", format!("Bearer {token}"))
402 .header(
403 "X-Snowflake-Authorization-Token-Type",
404 "PROGRAMMATIC_ACCESS_TOKEN",
405 )
406 .header("Content-Type", "application/json")
407 .header("Accept", "application/json")
408 .header("User-Agent", "edit_prediction_cli")
409 .body(AsyncBody::from(request_body.clone()))?;
410
411 let response = http_client
412 .send(http_request)
413 .await
414 .context("failed to send request to Snowflake SQL API")?;
415
416 let status = response.status();
417 let body_bytes = {
418 use futures::AsyncReadExt as _;
419
420 let mut body = response.into_body();
421 let mut bytes = Vec::new();
422 body.read_to_end(&mut bytes)
423 .await
424 .context("failed to read Snowflake SQL API response body")?;
425 bytes
426 };
427
428 if !status.is_success() && status.as_u16() != 202 {
429 let body_text = String::from_utf8_lossy(&body_bytes);
430 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
431 }
432
433 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
434 .context("failed to parse Snowflake SQL API response JSON")
435}
436
437pub async fn fetch_rejected_examples_after(
438 http_client: Arc<dyn HttpClient>,
439 after_timestamps: &[String],
440 max_rows_per_timestamp: usize,
441 background_executor: BackgroundExecutor,
442) -> Result<Vec<Example>> {
443 if after_timestamps.is_empty() {
444 return Ok(Vec::new());
445 }
446
447 let progress = Progress::global();
448
449 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
450 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
451 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
452 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
453 )?;
454 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
455
456 let mut all_examples = Vec::new();
457
458 for after_date in after_timestamps.iter() {
459 let step_progress_name = format!("rejected>{after_date}");
460 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
461 step_progress.set_substatus("querying");
462
463 // Join rejected events with their corresponding request events to get the full context.
464 // We filter for V3 sampling data which contains the structured input we need.
465 // We also filter for predictions that were actually shown to the user (was_shown = true)
466 // to focus on explicit user rejections rather than implicit cancellations.
467 let statement = indoc! {r#"
468 SELECT
469 req.event_properties:request_id::string AS request_id,
470 req.device_id::string AS device_id,
471 req.time::string AS time,
472 req.event_properties:input AS input,
473 req.event_properties:prompt::string AS prompt,
474 req.event_properties:output::string AS output,
475 rej.event_properties:was_shown::boolean AS was_shown,
476 rej.event_properties:reason::string AS reason
477 FROM events req
478 INNER JOIN events rej
479 ON req.event_properties:request_id = rej.event_properties:request_id
480 WHERE req.event_type = ?
481 AND rej.event_type = ?
482 AND req.event_properties:version = 'V3'
483 AND rej.event_properties:was_shown = true
484 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
485 ORDER BY req.time ASC
486 LIMIT ?
487 "#};
488
489 let request = json!({
490 "statement": statement,
491 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
492 "database": "EVENTS",
493 "schema": "PUBLIC",
494 "warehouse": "DBT",
495 "role": role,
496 "bindings": {
497 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
498 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
499 "3": { "type": "TEXT", "value": after_date },
500 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
501 }
502 });
503
504 let response = run_sql_with_polling(
505 http_client.clone(),
506 &base_url,
507 &token,
508 &request,
509 &step_progress,
510 background_executor.clone(),
511 )
512 .await?;
513
514 let total_rows = response
515 .result_set_meta_data
516 .as_ref()
517 .and_then(|m| m.num_rows)
518 .unwrap_or(response.data.len() as i64);
519
520 let num_partitions = response
521 .result_set_meta_data
522 .as_ref()
523 .map(|m| m.partition_info.len())
524 .unwrap_or(1)
525 .max(1);
526
527 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
528 step_progress.set_substatus("parsing");
529
530 all_examples.extend(rejected_examples_from_response(&response)?);
531
532 if num_partitions > 1 {
533 let statement_handle = response
534 .statement_handle
535 .as_ref()
536 .context("response has multiple partitions but no statementHandle")?;
537
538 for partition in 1..num_partitions {
539 step_progress.set_substatus(format!(
540 "fetching partition {}/{}",
541 partition + 1,
542 num_partitions
543 ));
544
545 let partition_response = fetch_partition(
546 http_client.clone(),
547 &base_url,
548 &token,
549 statement_handle,
550 partition,
551 )
552 .await?;
553
554 all_examples.extend(rejected_examples_from_response(&partition_response)?);
555 }
556 }
557
558 step_progress.set_substatus("done");
559 }
560
561 Ok(all_examples)
562}
563
564pub async fn fetch_requested_examples_after(
565 http_client: Arc<dyn HttpClient>,
566 after_timestamps: &[String],
567 max_rows_per_timestamp: usize,
568 background_executor: BackgroundExecutor,
569) -> Result<Vec<Example>> {
570 if after_timestamps.is_empty() {
571 return Ok(Vec::new());
572 }
573
574 let progress = Progress::global();
575
576 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
577 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
578 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
579 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
580 )?;
581 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
582
583 let mut all_examples = Vec::new();
584
585 for after_date in after_timestamps.iter() {
586 let step_progress_name = format!("requested>{after_date}");
587 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
588 step_progress.set_substatus("querying");
589
590 let statement = indoc! {r#"
591 SELECT
592 req.event_properties:request_id::string AS request_id,
593 req.device_id::string AS device_id,
594 req.time::string AS time,
595 req.event_properties:input AS input
596 FROM events req
597 WHERE req.event_type = ?
598 AND req.event_properties:version = 'V3'
599 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
600 ORDER BY req.time ASC
601 LIMIT ?
602 "#};
603
604 let request = json!({
605 "statement": statement,
606 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
607 "database": "EVENTS",
608 "schema": "PUBLIC",
609 "warehouse": "DBT",
610 "role": role,
611 "bindings": {
612 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
613 "2": { "type": "TEXT", "value": after_date },
614 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
615 }
616 });
617
618 let response = run_sql_with_polling(
619 http_client.clone(),
620 &base_url,
621 &token,
622 &request,
623 &step_progress,
624 background_executor.clone(),
625 )
626 .await?;
627
628 let total_rows = response
629 .result_set_meta_data
630 .as_ref()
631 .and_then(|m| m.num_rows)
632 .unwrap_or(response.data.len() as i64);
633
634 let num_partitions = response
635 .result_set_meta_data
636 .as_ref()
637 .map(|m| m.partition_info.len())
638 .unwrap_or(1)
639 .max(1);
640
641 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
642 step_progress.set_substatus("parsing");
643
644 let column_indices = get_column_indices(
645 &response.result_set_meta_data,
646 &["request_id", "device_id", "time", "input"],
647 );
648
649 all_examples.extend(requested_examples_from_response(
650 &response,
651 &column_indices,
652 )?);
653
654 if num_partitions > 1 {
655 let statement_handle = response
656 .statement_handle
657 .as_ref()
658 .context("response has multiple partitions but no statementHandle")?;
659
660 for partition in 1..num_partitions {
661 step_progress.set_substatus(format!(
662 "fetching partition {}/{}",
663 partition + 1,
664 num_partitions
665 ));
666
667 let partition_response = fetch_partition(
668 http_client.clone(),
669 &base_url,
670 &token,
671 statement_handle,
672 partition,
673 )
674 .await?;
675
676 all_examples.extend(requested_examples_from_response(
677 &partition_response,
678 &column_indices,
679 )?);
680 }
681 }
682
683 step_progress.set_substatus("done");
684 }
685
686 Ok(all_examples)
687}
688
689fn requested_examples_from_response<'a>(
690 response: &'a SnowflakeStatementResponse,
691 column_indices: &'a std::collections::HashMap<String, usize>,
692) -> Result<impl Iterator<Item = Example> + 'a> {
693 if let Some(code) = &response.code {
694 if code != SNOWFLAKE_SUCCESS_CODE {
695 anyhow::bail!(
696 "snowflake sql api returned error code={code} message={}",
697 response.message.as_deref().unwrap_or("<no message>")
698 );
699 }
700 }
701
702 let iter = response
703 .data
704 .iter()
705 .enumerate()
706 .filter_map(move |(row_index, data_row)| {
707 let get_string = |name: &str| -> Option<String> {
708 let index = column_indices.get(name).copied()?;
709 match data_row.get(index)? {
710 JsonValue::String(s) => Some(s.clone()),
711 JsonValue::Null => None,
712 other => Some(other.to_string()),
713 }
714 };
715
716 let get_json = |name: &str| -> Option<JsonValue> {
717 let index = column_indices.get(name).copied()?;
718 let value = data_row.get(index)?;
719 if value.is_null() {
720 return None;
721 }
722 match value {
723 JsonValue::String(s) => serde_json::from_str(s).ok(),
724 other => Some(other.clone()),
725 }
726 };
727
728 let request_id_str = get_string("request_id");
729 let device_id = get_string("device_id");
730 let time = get_string("time");
731 let input_json = get_json("input");
732 let input: Option<ZetaPromptInput> =
733 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
734
735 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
736 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
737 Some(build_example_from_snowflake(
738 request_id,
739 device_id,
740 time,
741 input,
742 vec!["requested".to_string()],
743 None,
744 ))
745 }
746 _ => {
747 log::warn!(
748 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
749 request_id_str.is_some(),
750 device_id.is_some(),
751 time.is_some(),
752 input_json.is_some(),
753 );
754 None
755 }
756 }
757 });
758
759 Ok(iter)
760}
761
762fn rejected_examples_from_response(
763 response: &SnowflakeStatementResponse,
764) -> Result<impl Iterator<Item = Example> + '_> {
765 if let Some(code) = &response.code {
766 if code != SNOWFLAKE_SUCCESS_CODE {
767 anyhow::bail!(
768 "snowflake sql api returned error code={code} message={}",
769 response.message.as_deref().unwrap_or("<no message>")
770 );
771 }
772 }
773
774 let column_indices = get_column_indices(
775 &response.result_set_meta_data,
776 &[
777 "request_id",
778 "device_id",
779 "time",
780 "input",
781 "prompt",
782 "output",
783 "was_shown",
784 "reason",
785 ],
786 );
787
788 let iter = response
789 .data
790 .iter()
791 .enumerate()
792 .filter_map(move |(row_index, data_row)| {
793 let get_string = |name: &str| -> Option<String> {
794 let index = column_indices.get(name).copied()?;
795 match data_row.get(index)? {
796 JsonValue::String(s) => Some(s.clone()),
797 JsonValue::Null => None,
798 other => Some(other.to_string()),
799 }
800 };
801
802 let get_json = |name: &str| -> Option<JsonValue> {
803 let index = column_indices.get(name).copied()?;
804 let value = data_row.get(index)?;
805 if value.is_null() {
806 return None;
807 }
808 match value {
809 JsonValue::String(s) => serde_json::from_str(s).ok(),
810 other => Some(other.clone()),
811 }
812 };
813
814 let get_bool = |name: &str| -> Option<bool> {
815 let index = column_indices.get(name).copied()?;
816 match data_row.get(index)? {
817 JsonValue::Bool(b) => Some(*b),
818 JsonValue::String(s) => s.parse().ok(),
819 _ => None,
820 }
821 };
822
823 let request_id_str = get_string("request_id");
824 let device_id = get_string("device_id");
825 let time = get_string("time");
826 let input_json = get_json("input");
827 let input: Option<ZetaPromptInput> =
828 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
829 let output = get_string("output");
830 let was_shown = get_bool("was_shown");
831 let reason = get_string("reason");
832
833 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
834 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
835 Some(build_rejected_example(
836 request_id,
837 device_id,
838 time,
839 input,
840 output,
841 was_shown,
842 reason,
843 ))
844 }
845 _ => {
846 log::warn!(
847 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
848 request_id_str.is_some(),
849 device_id.is_some(),
850 time.is_some(),
851 input_json.is_some(),
852 output.is_some(),
853 was_shown.is_some(),
854 reason.is_some()
855 );
856 None
857 }
858 }
859 });
860
861 Ok(iter)
862}
863
864fn build_rejected_example(
865 request_id: String,
866 device_id: String,
867 time: String,
868 input: ZetaPromptInput,
869 output: String,
870 was_shown: bool,
871 reason: String,
872) -> Example {
873 let rejected_patch = build_output_patch(
874 &input.cursor_path,
875 input.cursor_excerpt.as_ref(),
876 &input.editable_range_in_excerpt,
877 &output,
878 );
879 let mut example = build_example_from_snowflake(
880 request_id,
881 device_id,
882 time,
883 input,
884 vec![format!("rejection:{}", reason.to_lowercase())],
885 Some(RejectionInfo { reason, was_shown }),
886 );
887 example.spec.rejected_patch = Some(rejected_patch);
888 example
889}
890
/// Rejection details copied into an example's telemetry by
/// `build_example_from_snowflake`.
struct RejectionInfo {
    /// Value of the `reason` column of the rejected event.
    reason: String,
    /// Value of the `was_shown` column of the rejected event.
    was_shown: bool,
}
895
896fn build_example_from_snowflake(
897 request_id: String,
898 device_id: String,
899 time: String,
900 input: ZetaPromptInput,
901 tags: Vec<String>,
902 rejection: Option<RejectionInfo>,
903) -> Example {
904 let events: Vec<CapturedEvent> = input
905 .events
906 .iter()
907 .map(|event| match event.as_ref() {
908 zeta_prompt::Event::BufferChange {
909 path,
910 old_path,
911 diff,
912 predicted,
913 in_open_source_repo,
914 } => CapturedEvent {
915 path: path.clone(),
916 old_path: old_path.clone(),
917 diff: diff.clone(),
918 predicted: *predicted,
919 in_open_source_repo: *in_open_source_repo,
920 },
921 })
922 .collect();
923
924 let related_files: Vec<CapturedRelatedFile> = input
925 .related_files
926 .iter()
927 .map(|rf| CapturedRelatedFile {
928 path: rf.path.clone(),
929 max_row: rf.max_row,
930 excerpts: rf
931 .excerpts
932 .iter()
933 .map(|e| CapturedRelatedExcerpt {
934 row_range: e.row_range.clone(),
935 text: e.text.to_string(),
936 })
937 .collect(),
938 })
939 .collect();
940
941 let cursor_excerpt = input.cursor_excerpt.as_ref();
942 let cursor_offset = input.cursor_offset_in_excerpt;
943
944 let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
945
946 let mut edit_history = String::new();
947 for event in &input.events {
948 zeta_prompt::write_event(&mut edit_history, event);
949 edit_history.push('\n');
950 }
951
952 let (rejection_reason, was_shown) = match &rejection {
953 Some(r) => (r.reason.clone(), r.was_shown),
954 None => (String::new(), false),
955 };
956
957 let spec = ExampleSpec {
958 name: request_id.clone(),
959 repository_url: String::new(),
960 revision: String::new(),
961 tags,
962 reasoning: None,
963 uncommitted_diff: String::new(),
964 cursor_path: input.cursor_path.clone(),
965 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
966 edit_history,
967 expected_patches: Vec::new(),
968 rejected_patch: None,
969 captured_prompt_input: Some(CapturedPromptInput {
970 cursor_file_content: cursor_excerpt.to_string(),
971 cursor_offset,
972 cursor_row,
973 cursor_column,
974 events,
975 related_files,
976 }),
977 telemetry: Some(TelemetrySource {
978 request_id,
979 device_id,
980 time,
981 rejection_reason,
982 was_shown,
983 }),
984 };
985
986 Example {
987 spec,
988 prompt_inputs: None,
989 prompt: None,
990 predictions: Vec::new(),
991 score: Vec::new(),
992 qa: Vec::new(),
993 state: None,
994 }
995}
996
/// Converts a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// The row is the number of `'\n'` characters that start before `offset`.
/// The column is the byte distance from the start of that row to `offset`,
/// so it counts bytes rather than characters and keeps growing when `offset`
/// lies past the end of `text`.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    let (row, line_start) = text
        .char_indices()
        .take_while(|&(index, _)| index < offset)
        .filter(|&(_, ch)| ch == '\n')
        .fold((0u32, 0usize), |(rows, _), (index, _)| (rows + 1, index + 1));

    (row, (offset - line_start) as u32)
}
1012
/// Returns `excerpt` with the marker `[CURSOR_POSITION]` inserted at
/// `cursor_offset` (a byte offset).
///
/// The offset is clamped to the length of `excerpt`. If it falls inside a
/// multi-byte UTF-8 character, the marker is placed before that character
/// instead of panicking on an invalid slice boundary.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1018
1019fn build_output_patch(
1020 cursor_path: &std::path::Path,
1021 cursor_excerpt: &str,
1022 editable_range: &std::ops::Range<usize>,
1023 model_output: &str,
1024) -> String {
1025 let old_text = &cursor_excerpt[editable_range.clone()];
1026
1027 let editable_start_row = cursor_excerpt[..editable_range.start]
1028 .chars()
1029 .filter(|&c| c == '\n')
1030 .count() as u32;
1031
1032 let diff_body = language::unified_diff_with_offsets(
1033 old_text,
1034 model_output,
1035 editable_start_row,
1036 editable_start_row,
1037 );
1038
1039 let mut patch = String::new();
1040 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1041 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1042 patch.push_str(&diff_body);
1043 patch
1044}
1045
1046fn get_column_indices(
1047 meta: &Option<SnowflakeResultSetMetaData>,
1048 names: &[&str],
1049) -> std::collections::HashMap<String, usize> {
1050 let mut indices = std::collections::HashMap::new();
1051 if let Some(meta) = meta {
1052 for (index, col) in meta.row_type.iter().enumerate() {
1053 for &name in names {
1054 if col.name.eq_ignore_ascii_case(name) {
1055 indices.insert(name.to_string(), index);
1056 }
1057 }
1058 }
1059 }
1060 indices
1061}