1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11use telemetry_events::EditPredictionRating;
12
13use zeta_prompt::ZetaPromptInput;
14
15use crate::example::Example;
16use crate::progress::{InfoStyle, Progress, Step};
17use edit_prediction::example_spec::{
18 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
19 TelemetrySource,
20};
21use std::fmt::Write as _;
22
/// Snowflake SQL API response `code` for a successfully completed statement. Any other `code`
/// present on a result page is reported as an error by the row-parsing helpers.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Snowflake SQL API response `code` meaning the statement is still executing asynchronously;
/// `run_sql_with_polling` keeps polling while it sees this code.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` of captured-example events (bound into the `fetch_captured_examples_after` query).
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` of prediction-request events (used by the requested and rejected queries).
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` of prediction-rejection events (joined to requests in the rejected query).
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// `event_type` of user-rating events (bound into the `fetch_rated_examples_after` query).
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";

/// Value sent as the `timeout` field of every statement submission.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay between successive polls of a statement that is still running.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Number of polls before giving up; together with `POLL_INTERVAL` this bounds the wait at
/// roughly `MAX_POLL_ATTEMPTS * POLL_INTERVAL` (240 seconds).
const MAX_POLL_ATTEMPTS: usize = 120;
33
/// Parses an input token of the form `captured-after:{timestamp}`.
///
/// Returns the `{timestamp}` portion (which may be empty), or `None` when `input` does not
/// begin with the `captured-after:` prefix.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input.strip_prefix(PREFIX)
}
38
/// Parses an input token of the form `rejected-after:{timestamp}`.
///
/// Returns the `{timestamp}` portion (which may be empty), or `None` when `input` does not
/// begin with the `rejected-after:` prefix.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    input.strip_prefix(PREFIX)
}
43
/// Parses an input token of the form `requested-after:{timestamp}`.
///
/// Returns the `{timestamp}` portion (which may be empty), or `None` when `input` does not
/// begin with the `requested-after:` prefix.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    input.strip_prefix(PREFIX)
}
48
49/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
50/// or `rated-negative-after:{timestamp}`.
51/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
52pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
53 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
54 Some((timestamp, Some(EditPredictionRating::Positive)))
55 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
56 Some((timestamp, Some(EditPredictionRating::Negative)))
57 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
58 Some((timestamp, None))
59 } else {
60 None
61 }
62}
63
64pub async fn fetch_captured_examples_after(
65 http_client: Arc<dyn HttpClient>,
66 after_timestamps: &[String],
67 max_rows_per_timestamp: usize,
68 background_executor: BackgroundExecutor,
69) -> Result<Vec<Example>> {
70 if after_timestamps.is_empty() {
71 return Ok(Vec::new());
72 }
73
74 let progress = Progress::global();
75
76 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
77 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
78 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
79 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
80 )?;
81 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
82
83 let mut all_examples = Vec::new();
84
85 for after_date in after_timestamps.iter() {
86 let step_progress_name = format!(">{after_date}");
87 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
88 step_progress.set_substatus("querying");
89
90 let statement = indoc! {r#"
91 SELECT
92 event_properties:example AS example
93 FROM events
94 WHERE event_type = ?
95 AND time > TRY_TO_TIMESTAMP_NTZ(?)
96 ORDER BY time ASC
97 LIMIT ?
98 "#};
99
100 let request = json!({
101 "statement": statement,
102 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
103 "database": "EVENTS",
104 "schema": "PUBLIC",
105 "warehouse": "DBT",
106 "role": role,
107 "bindings": {
108 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
109 "2": { "type": "TEXT", "value": after_date },
110 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
111 }
112 });
113
114 let response = run_sql_with_polling(
115 http_client.clone(),
116 &base_url,
117 &token,
118 &request,
119 &step_progress,
120 background_executor.clone(),
121 )
122 .await?;
123
124 let total_rows = response
125 .result_set_meta_data
126 .as_ref()
127 .and_then(|m| m.num_rows)
128 .unwrap_or(response.data.len() as i64);
129
130 let num_partitions = response
131 .result_set_meta_data
132 .as_ref()
133 .map(|m| m.partition_info.len())
134 .unwrap_or(1)
135 .max(1);
136
137 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
138 step_progress.set_substatus("parsing");
139
140 let example_index = response
141 .result_set_meta_data
142 .as_ref()
143 .and_then(|m| {
144 m.row_type.iter().enumerate().find_map(|(index, col)| {
145 if col.name.eq_ignore_ascii_case("example") {
146 Some(index)
147 } else {
148 None
149 }
150 })
151 })
152 .unwrap_or(0);
153
154 all_examples.extend(examples_from_response(&response, example_index)?);
155
156 if num_partitions > 1 {
157 let statement_handle = response
158 .statement_handle
159 .as_ref()
160 .context("response has multiple partitions but no statementHandle")?;
161
162 for partition in 1..num_partitions {
163 step_progress.set_substatus(format!(
164 "fetching partition {}/{}",
165 partition + 1,
166 num_partitions
167 ));
168
169 let partition_response = fetch_partition(
170 http_client.clone(),
171 &base_url,
172 &token,
173 statement_handle,
174 partition,
175 )
176 .await?;
177
178 all_examples.extend(examples_from_response(&partition_response, example_index)?);
179 }
180 }
181
182 step_progress.set_substatus("done");
183 }
184
185 Ok(all_examples)
186}
187
188#[derive(Debug, Clone, Deserialize)]
189#[serde(rename_all = "camelCase")]
190struct SnowflakeStatementResponse {
191 #[serde(default)]
192 data: Vec<Vec<JsonValue>>,
193 #[serde(default)]
194 result_set_meta_data: Option<SnowflakeResultSetMetaData>,
195 #[serde(default)]
196 code: Option<String>,
197 #[serde(default)]
198 message: Option<String>,
199 #[serde(default)]
200 statement_handle: Option<String>,
201}
202
203#[derive(Debug, Clone, Deserialize)]
204#[serde(rename_all = "camelCase")]
205struct SnowflakeResultSetMetaData {
206 #[serde(default, rename = "rowType")]
207 row_type: Vec<SnowflakeColumnMeta>,
208 #[serde(default)]
209 num_rows: Option<i64>,
210 #[serde(default)]
211 partition_info: Vec<SnowflakePartitionInfo>,
212}
213
214#[derive(Debug, Clone, Deserialize)]
215#[serde(rename_all = "camelCase")]
216struct SnowflakePartitionInfo {}
217
218#[derive(Debug, Clone, Deserialize)]
219struct SnowflakeColumnMeta {
220 #[serde(default)]
221 name: String,
222}
223
224fn examples_from_response(
225 response: &SnowflakeStatementResponse,
226 example_index: usize,
227) -> Result<impl Iterator<Item = Example> + '_> {
228 if let Some(code) = &response.code {
229 if code != SNOWFLAKE_SUCCESS_CODE {
230 anyhow::bail!(
231 "snowflake sql api returned error code={code} message={}",
232 response.message.as_deref().unwrap_or("<no message>")
233 );
234 }
235 }
236
237 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
238 let Some(example_value) = data_row.get(example_index) else {
239 return None;
240 };
241 if example_value.is_null() {
242 return None;
243 }
244
245 let parse_result = match example_value {
246 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
247 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
248 };
249
250 match parse_result {
251 Ok(spec) => Some(Example {
252 spec,
253 prompt_inputs: None,
254 prompt: None,
255 predictions: Vec::new(),
256 score: Vec::new(),
257 qa: Vec::new(),
258 state: None,
259 }),
260 Err(error) => {
261 let raw_json = serde_json::to_string_pretty(example_value)
262 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
263 log::error!(
264 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
265 );
266 None
267 }
268 }
269 });
270
271 Ok(iter)
272}
273
274async fn run_sql_with_polling(
275 http_client: Arc<dyn HttpClient>,
276 base_url: &str,
277 token: &str,
278 request: &serde_json::Value,
279 step_progress: &crate::progress::StepProgress,
280 background_executor: BackgroundExecutor,
281) -> Result<SnowflakeStatementResponse> {
282 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
283
284 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
285 let statement_handle = response
286 .statement_handle
287 .as_ref()
288 .context("async query response missing statementHandle")?
289 .clone();
290
291 for attempt in 1..=MAX_POLL_ATTEMPTS {
292 step_progress.set_substatus(format!("polling ({attempt})"));
293
294 background_executor.timer(POLL_INTERVAL).await;
295
296 response =
297 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
298
299 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
300 break;
301 }
302 }
303
304 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
305 anyhow::bail!(
306 "query still running after {} poll attempts ({} seconds)",
307 MAX_POLL_ATTEMPTS,
308 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
309 );
310 }
311 }
312
313 Ok(response)
314}
315
316async fn fetch_partition(
317 http_client: Arc<dyn HttpClient>,
318 base_url: &str,
319 token: &str,
320 statement_handle: &str,
321 partition: usize,
322) -> Result<SnowflakeStatementResponse> {
323 let url = format!(
324 "{}/api/v2/statements/{}?partition={}",
325 base_url.trim_end_matches('/'),
326 statement_handle,
327 partition
328 );
329
330 let http_request = Request::builder()
331 .method(Method::GET)
332 .uri(url.as_str())
333 .header("Authorization", format!("Bearer {token}"))
334 .header(
335 "X-Snowflake-Authorization-Token-Type",
336 "PROGRAMMATIC_ACCESS_TOKEN",
337 )
338 .header("Accept", "application/json")
339 .header("Accept-Encoding", "gzip")
340 .header("User-Agent", "edit_prediction_cli")
341 .body(AsyncBody::empty())?;
342
343 let response = http_client
344 .send(http_request)
345 .await
346 .context("failed to send partition request to Snowflake SQL API")?;
347
348 let status = response.status();
349 let content_encoding = response
350 .headers()
351 .get("content-encoding")
352 .and_then(|v| v.to_str().ok())
353 .map(|s| s.to_lowercase());
354
355 let body_bytes = {
356 use futures::AsyncReadExt as _;
357
358 let mut body = response.into_body();
359 let mut bytes = Vec::new();
360 body.read_to_end(&mut bytes)
361 .await
362 .context("failed to read Snowflake SQL API partition response body")?;
363 bytes
364 };
365
366 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
367 let mut decoder = GzDecoder::new(&body_bytes[..]);
368 let mut decompressed = Vec::new();
369 decoder
370 .read_to_end(&mut decompressed)
371 .context("failed to decompress gzip response")?;
372 decompressed
373 } else {
374 body_bytes
375 };
376
377 if !status.is_success() && status.as_u16() != 202 {
378 let body_text = String::from_utf8_lossy(&body_bytes);
379 anyhow::bail!(
380 "snowflake sql api partition request http {}: {}",
381 status.as_u16(),
382 body_text
383 );
384 }
385
386 if body_bytes.is_empty() {
387 anyhow::bail!(
388 "snowflake sql api partition {} returned empty response body (http {})",
389 partition,
390 status.as_u16()
391 );
392 }
393
394 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
395 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
396 format!(
397 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
398 partition,
399 status.as_u16(),
400 body_preview
401 )
402 })
403}
404
405async fn run_sql(
406 http_client: Arc<dyn HttpClient>,
407 base_url: &str,
408 token: &str,
409 request: &serde_json::Value,
410) -> Result<SnowflakeStatementResponse> {
411 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
412
413 let request_body =
414 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
415
416 let http_request = Request::builder()
417 .method(Method::POST)
418 .uri(url.as_str())
419 .header("Authorization", format!("Bearer {token}"))
420 .header(
421 "X-Snowflake-Authorization-Token-Type",
422 "PROGRAMMATIC_ACCESS_TOKEN",
423 )
424 .header("Content-Type", "application/json")
425 .header("Accept", "application/json")
426 .header("User-Agent", "edit_prediction_cli")
427 .body(AsyncBody::from(request_body.clone()))?;
428
429 let response = http_client
430 .send(http_request)
431 .await
432 .context("failed to send request to Snowflake SQL API")?;
433
434 let status = response.status();
435 let body_bytes = {
436 use futures::AsyncReadExt as _;
437
438 let mut body = response.into_body();
439 let mut bytes = Vec::new();
440 body.read_to_end(&mut bytes)
441 .await
442 .context("failed to read Snowflake SQL API response body")?;
443 bytes
444 };
445
446 if !status.is_success() && status.as_u16() != 202 {
447 let body_text = String::from_utf8_lossy(&body_bytes);
448 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
449 }
450
451 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
452 .context("failed to parse Snowflake SQL API response JSON")
453}
454
455pub async fn fetch_rejected_examples_after(
456 http_client: Arc<dyn HttpClient>,
457 after_timestamps: &[String],
458 max_rows_per_timestamp: usize,
459 background_executor: BackgroundExecutor,
460) -> Result<Vec<Example>> {
461 if after_timestamps.is_empty() {
462 return Ok(Vec::new());
463 }
464
465 let progress = Progress::global();
466
467 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
468 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
469 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
470 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
471 )?;
472 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
473
474 let mut all_examples = Vec::new();
475
476 for after_date in after_timestamps.iter() {
477 let step_progress_name = format!("rejected>{after_date}");
478 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
479 step_progress.set_substatus("querying");
480
481 // Join rejected events with their corresponding request events to get the full context.
482 // We filter for V3 sampling data which contains the structured input we need.
483 // We also filter for predictions that were actually shown to the user (was_shown = true)
484 // to focus on explicit user rejections rather than implicit cancellations.
485 let statement = indoc! {r#"
486 SELECT
487 req.event_properties:request_id::string AS request_id,
488 req.device_id::string AS device_id,
489 req.time::string AS time,
490 req.event_properties:input AS input,
491 req.event_properties:prompt::string AS prompt,
492 req.event_properties:output::string AS output,
493 rej.event_properties:was_shown::boolean AS was_shown,
494 rej.event_properties:reason::string AS reason
495 FROM events req
496 INNER JOIN events rej
497 ON req.event_properties:request_id = rej.event_properties:request_id
498 WHERE req.event_type = ?
499 AND rej.event_type = ?
500 AND req.event_properties:version = 'V3'
501 AND rej.event_properties:was_shown = true
502 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
503 ORDER BY req.time ASC
504 LIMIT ?
505 "#};
506
507 let request = json!({
508 "statement": statement,
509 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
510 "database": "EVENTS",
511 "schema": "PUBLIC",
512 "warehouse": "DBT",
513 "role": role,
514 "bindings": {
515 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
516 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
517 "3": { "type": "TEXT", "value": after_date },
518 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
519 }
520 });
521
522 let response = run_sql_with_polling(
523 http_client.clone(),
524 &base_url,
525 &token,
526 &request,
527 &step_progress,
528 background_executor.clone(),
529 )
530 .await?;
531
532 let total_rows = response
533 .result_set_meta_data
534 .as_ref()
535 .and_then(|m| m.num_rows)
536 .unwrap_or(response.data.len() as i64);
537
538 let num_partitions = response
539 .result_set_meta_data
540 .as_ref()
541 .map(|m| m.partition_info.len())
542 .unwrap_or(1)
543 .max(1);
544
545 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
546 step_progress.set_substatus("parsing");
547
548 let column_indices = get_column_indices(
549 &response.result_set_meta_data,
550 &[
551 "request_id",
552 "device_id",
553 "time",
554 "input",
555 "prompt",
556 "output",
557 "was_shown",
558 "reason",
559 ],
560 );
561
562 all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
563
564 if num_partitions > 1 {
565 let statement_handle = response
566 .statement_handle
567 .as_ref()
568 .context("response has multiple partitions but no statementHandle")?;
569
570 for partition in 1..num_partitions {
571 step_progress.set_substatus(format!(
572 "fetching partition {}/{}",
573 partition + 1,
574 num_partitions
575 ));
576
577 let partition_response = fetch_partition(
578 http_client.clone(),
579 &base_url,
580 &token,
581 statement_handle,
582 partition,
583 )
584 .await?;
585
586 all_examples.extend(rejected_examples_from_response(
587 &partition_response,
588 &column_indices,
589 )?);
590 }
591 }
592
593 step_progress.set_substatus("done");
594 }
595
596 Ok(all_examples)
597}
598
599pub async fn fetch_requested_examples_after(
600 http_client: Arc<dyn HttpClient>,
601 after_timestamps: &[String],
602 max_rows_per_timestamp: usize,
603 background_executor: BackgroundExecutor,
604) -> Result<Vec<Example>> {
605 if after_timestamps.is_empty() {
606 return Ok(Vec::new());
607 }
608
609 let progress = Progress::global();
610
611 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
612 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
613 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
614 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
615 )?;
616 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
617
618 let mut all_examples = Vec::new();
619
620 for after_date in after_timestamps.iter() {
621 let step_progress_name = format!("requested>{after_date}");
622 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
623 step_progress.set_substatus("querying");
624
625 let statement = indoc! {r#"
626 SELECT
627 req.event_properties:request_id::string AS request_id,
628 req.device_id::string AS device_id,
629 req.time::string AS time,
630 req.event_properties:input AS input
631 FROM events req
632 WHERE req.event_type = ?
633 AND req.event_properties:version = 'V3'
634 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
635 ORDER BY req.time ASC
636 LIMIT ?
637 "#};
638
639 let request = json!({
640 "statement": statement,
641 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
642 "database": "EVENTS",
643 "schema": "PUBLIC",
644 "warehouse": "DBT",
645 "role": role,
646 "bindings": {
647 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
648 "2": { "type": "TEXT", "value": after_date },
649 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
650 }
651 });
652
653 let response = run_sql_with_polling(
654 http_client.clone(),
655 &base_url,
656 &token,
657 &request,
658 &step_progress,
659 background_executor.clone(),
660 )
661 .await?;
662
663 let total_rows = response
664 .result_set_meta_data
665 .as_ref()
666 .and_then(|m| m.num_rows)
667 .unwrap_or(response.data.len() as i64);
668
669 let num_partitions = response
670 .result_set_meta_data
671 .as_ref()
672 .map(|m| m.partition_info.len())
673 .unwrap_or(1)
674 .max(1);
675
676 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
677 step_progress.set_substatus("parsing");
678
679 let column_indices = get_column_indices(
680 &response.result_set_meta_data,
681 &["request_id", "device_id", "time", "input"],
682 );
683
684 all_examples.extend(requested_examples_from_response(
685 &response,
686 &column_indices,
687 )?);
688
689 if num_partitions > 1 {
690 let statement_handle = response
691 .statement_handle
692 .as_ref()
693 .context("response has multiple partitions but no statementHandle")?;
694
695 for partition in 1..num_partitions {
696 step_progress.set_substatus(format!(
697 "fetching partition {}/{}",
698 partition + 1,
699 num_partitions
700 ));
701
702 let partition_response = fetch_partition(
703 http_client.clone(),
704 &base_url,
705 &token,
706 statement_handle,
707 partition,
708 )
709 .await?;
710
711 all_examples.extend(requested_examples_from_response(
712 &partition_response,
713 &column_indices,
714 )?);
715 }
716 }
717
718 step_progress.set_substatus("done");
719 }
720
721 Ok(all_examples)
722}
723
724pub async fn fetch_rated_examples_after(
725 http_client: Arc<dyn HttpClient>,
726 inputs: &[(String, Option<EditPredictionRating>)],
727 max_rows_per_timestamp: usize,
728 background_executor: BackgroundExecutor,
729) -> Result<Vec<Example>> {
730 if inputs.is_empty() {
731 return Ok(Vec::new());
732 }
733
734 let progress = Progress::global();
735
736 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
737 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
738 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
739 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
740 )?;
741 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
742
743 let mut all_examples = Vec::new();
744
745 for (after_date, rating_filter) in inputs.iter() {
746 let filter_label = match rating_filter {
747 None => "",
748 Some(EditPredictionRating::Positive) => ":positive",
749 Some(EditPredictionRating::Negative) => ":negative",
750 };
751 let step_progress_name = format!("rated{filter_label}>{after_date}");
752 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
753 step_progress.set_substatus("querying");
754
755 let rating_value = rating_filter.as_ref().map(|r| match r {
756 EditPredictionRating::Positive => "Positive",
757 EditPredictionRating::Negative => "Negative",
758 });
759
760 let statement = indoc! {r#"
761 SELECT
762 event_properties:inputs AS inputs,
763 event_properties:output::string AS output,
764 event_properties:rating::string AS rating,
765 event_properties:feedback::string AS feedback,
766 device_id::string AS device_id,
767 time::string AS time
768 FROM events
769 WHERE event_type = ?
770 AND (? IS NULL OR event_properties:rating::string = ?)
771 AND time > TRY_TO_TIMESTAMP_NTZ(?)
772 AND event_properties:inputs IS NOT NULL
773 AND event_properties:inputs:cursor_excerpt IS NOT NULL
774 AND event_properties:output IS NOT NULL
775 ORDER BY time ASC
776 LIMIT ?
777 "#};
778
779 let bindings = json!({
780 "1": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
781 "2": { "type": "TEXT", "value": rating_value },
782 "3": { "type": "TEXT", "value": rating_value },
783 "4": { "type": "TEXT", "value": after_date },
784 "5": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
785 });
786
787 let request = json!({
788 "statement": statement,
789 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
790 "database": "EVENTS",
791 "schema": "PUBLIC",
792 "warehouse": "DBT",
793 "role": role,
794 "bindings": bindings
795 });
796
797 let response = run_sql_with_polling(
798 http_client.clone(),
799 &base_url,
800 &token,
801 &request,
802 &step_progress,
803 background_executor.clone(),
804 )
805 .await?;
806
807 let total_rows = response
808 .result_set_meta_data
809 .as_ref()
810 .and_then(|m| m.num_rows)
811 .unwrap_or(response.data.len() as i64);
812
813 let num_partitions = response
814 .result_set_meta_data
815 .as_ref()
816 .map(|m| m.partition_info.len())
817 .unwrap_or(1)
818 .max(1);
819
820 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
821 step_progress.set_substatus("parsing");
822
823 let column_indices = get_column_indices(
824 &response.result_set_meta_data,
825 &[
826 "inputs",
827 "output",
828 "rating",
829 "feedback",
830 "device_id",
831 "time",
832 ],
833 );
834
835 all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
836
837 if num_partitions > 1 {
838 let statement_handle = response
839 .statement_handle
840 .as_ref()
841 .context("response has multiple partitions but no statementHandle")?;
842
843 for partition in 1..num_partitions {
844 step_progress.set_substatus(format!(
845 "fetching partition {}/{}",
846 partition + 1,
847 num_partitions
848 ));
849
850 let partition_response = fetch_partition(
851 http_client.clone(),
852 &base_url,
853 &token,
854 statement_handle,
855 partition,
856 )
857 .await?;
858
859 all_examples.extend(rated_examples_from_response(
860 &partition_response,
861 &column_indices,
862 )?);
863 }
864 }
865
866 step_progress.set_substatus("done");
867 }
868
869 Ok(all_examples)
870}
871
872fn rated_examples_from_response<'a>(
873 response: &'a SnowflakeStatementResponse,
874 column_indices: &'a std::collections::HashMap<String, usize>,
875) -> Result<impl Iterator<Item = Example> + 'a> {
876 if let Some(code) = &response.code {
877 if code != SNOWFLAKE_SUCCESS_CODE {
878 anyhow::bail!(
879 "snowflake sql api returned error code={code} message={}",
880 response.message.as_deref().unwrap_or("<no message>")
881 );
882 }
883 }
884
885 let iter = response
886 .data
887 .iter()
888 .enumerate()
889 .filter_map(move |(row_index, data_row)| {
890 let get_string = |name: &str| -> Option<String> {
891 let index = column_indices.get(name).copied()?;
892 match data_row.get(index)? {
893 JsonValue::String(s) => Some(s.clone()),
894 JsonValue::Null => None,
895 other => Some(other.to_string()),
896 }
897 };
898
899 let get_json = |name: &str| -> Option<JsonValue> {
900 let index = column_indices.get(name).copied()?;
901 let value = data_row.get(index)?;
902 if value.is_null() {
903 return None;
904 }
905 match value {
906 JsonValue::String(s) => serde_json::from_str(s).ok(),
907 other => Some(other.clone()),
908 }
909 };
910
911 let inputs_json = get_json("inputs");
912 let inputs: Option<ZetaPromptInput> = match &inputs_json {
913 Some(v) => match serde_json::from_value(v.clone()) {
914 Ok(parsed) => Some(parsed),
915 Err(e) => {
916 log::warn!(
917 "skipping row {row_index}: failed to parse inputs - {e}",
918 );
919 return None;
920 }
921 },
922 None => None,
923 };
924 let output = get_string("output");
925 let rating = get_string("rating");
926 let feedback = get_string("feedback").unwrap_or_default();
927 let device_id = get_string("device_id");
928 let time = get_string("time");
929
930 match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
931 (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
932 Some(build_rated_example(
933 device_id,
934 time,
935 inputs,
936 output,
937 rating,
938 feedback,
939 ))
940 }
941 _ => {
942 log::warn!(
943 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
944 inputs_json.is_some(),
945 output.is_some(),
946 rating.is_some(),
947 device_id.is_some(),
948 time.is_some(),
949 );
950 None
951 }
952 }
953 });
954
955 Ok(iter)
956}
957
958fn build_rated_example(
959 device_id: String,
960 time: String,
961 input: ZetaPromptInput,
962 output: String,
963 rating: String,
964 feedback: String,
965) -> Example {
966 let parsed_rating = if rating == "Positive" {
967 EditPredictionRating::Positive
968 } else {
969 EditPredictionRating::Negative
970 };
971 let is_positive = parsed_rating == EditPredictionRating::Positive;
972 let request_id = format!("rated-{}-{}", device_id, time);
973
974 let tags = if is_positive {
975 vec!["rated:positive".to_string()]
976 } else {
977 vec!["rated:negative".to_string()]
978 };
979
980 let mut example = build_example_from_snowflake(request_id, device_id, time, input, tags, None);
981
982 example.spec.rating = Some(parsed_rating);
983
984 if !feedback.is_empty() {
985 example
986 .spec
987 .human_feedback
988 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
989 }
990
991 if is_positive {
992 example.spec.expected_patches = vec![output];
993 } else {
994 example.spec.rejected_patch = Some(output);
995 }
996
997 example
998}
999
1000fn requested_examples_from_response<'a>(
1001 response: &'a SnowflakeStatementResponse,
1002 column_indices: &'a std::collections::HashMap<String, usize>,
1003) -> Result<impl Iterator<Item = Example> + 'a> {
1004 if let Some(code) = &response.code {
1005 if code != SNOWFLAKE_SUCCESS_CODE {
1006 anyhow::bail!(
1007 "snowflake sql api returned error code={code} message={}",
1008 response.message.as_deref().unwrap_or("<no message>")
1009 );
1010 }
1011 }
1012
1013 let iter = response
1014 .data
1015 .iter()
1016 .enumerate()
1017 .filter_map(move |(row_index, data_row)| {
1018 let get_string = |name: &str| -> Option<String> {
1019 let index = column_indices.get(name).copied()?;
1020 match data_row.get(index)? {
1021 JsonValue::String(s) => Some(s.clone()),
1022 JsonValue::Null => None,
1023 other => Some(other.to_string()),
1024 }
1025 };
1026
1027 let get_json = |name: &str| -> Option<JsonValue> {
1028 let index = column_indices.get(name).copied()?;
1029 let value = data_row.get(index)?;
1030 if value.is_null() {
1031 return None;
1032 }
1033 match value {
1034 JsonValue::String(s) => serde_json::from_str(s).ok(),
1035 other => Some(other.clone()),
1036 }
1037 };
1038
1039 let request_id_str = get_string("request_id");
1040 let device_id = get_string("device_id");
1041 let time = get_string("time");
1042 let input_json = get_json("input");
1043 let input: Option<ZetaPromptInput> =
1044 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1045
1046 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1047 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1048 Some(build_example_from_snowflake(
1049 request_id,
1050 device_id,
1051 time,
1052 input,
1053 vec!["requested".to_string()],
1054 None,
1055 ))
1056 }
1057 _ => {
1058 log::warn!(
1059 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1060 request_id_str.is_some(),
1061 device_id.is_some(),
1062 time.is_some(),
1063 input_json.is_some(),
1064 );
1065 None
1066 }
1067 }
1068 });
1069
1070 Ok(iter)
1071}
1072
1073fn rejected_examples_from_response<'a>(
1074 response: &'a SnowflakeStatementResponse,
1075 column_indices: &'a std::collections::HashMap<String, usize>,
1076) -> Result<impl Iterator<Item = Example> + 'a> {
1077 if let Some(code) = &response.code {
1078 if code != SNOWFLAKE_SUCCESS_CODE {
1079 anyhow::bail!(
1080 "snowflake sql api returned error code={code} message={}",
1081 response.message.as_deref().unwrap_or("<no message>")
1082 );
1083 }
1084 }
1085
1086 let iter = response
1087 .data
1088 .iter()
1089 .enumerate()
1090 .filter_map(move |(row_index, data_row)| {
1091 let get_string = |name: &str| -> Option<String> {
1092 let index = column_indices.get(name).copied()?;
1093 match data_row.get(index)? {
1094 JsonValue::String(s) => Some(s.clone()),
1095 JsonValue::Null => None,
1096 other => Some(other.to_string()),
1097 }
1098 };
1099
1100 let get_json = |name: &str| -> Option<JsonValue> {
1101 let index = column_indices.get(name).copied()?;
1102 let value = data_row.get(index)?;
1103 if value.is_null() {
1104 return None;
1105 }
1106 match value {
1107 JsonValue::String(s) => serde_json::from_str(s).ok(),
1108 other => Some(other.clone()),
1109 }
1110 };
1111
1112 let get_bool = |name: &str| -> Option<bool> {
1113 let index = column_indices.get(name).copied()?;
1114 match data_row.get(index)? {
1115 JsonValue::Bool(b) => Some(*b),
1116 JsonValue::String(s) => s.parse().ok(),
1117 _ => None,
1118 }
1119 };
1120
1121 let request_id_str = get_string("request_id");
1122 let device_id = get_string("device_id");
1123 let time = get_string("time");
1124 let input_json = get_json("input");
1125 let input: Option<ZetaPromptInput> =
1126 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1127 let output = get_string("output");
1128 let was_shown = get_bool("was_shown");
1129 let reason = get_string("reason");
1130
1131 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1132 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1133 Some(build_rejected_example(
1134 request_id,
1135 device_id,
1136 time,
1137 input,
1138 output,
1139 was_shown,
1140 reason,
1141 ))
1142 }
1143 _ => {
1144 log::warn!(
1145 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1146 request_id_str.is_some(),
1147 device_id.is_some(),
1148 time.is_some(),
1149 input_json.is_some(),
1150 output.is_some(),
1151 was_shown.is_some(),
1152 reason.is_some()
1153 );
1154 None
1155 }
1156 }
1157 });
1158
1159 Ok(iter)
1160}
1161
1162fn build_rejected_example(
1163 request_id: String,
1164 device_id: String,
1165 time: String,
1166 input: ZetaPromptInput,
1167 output: String,
1168 was_shown: bool,
1169 reason: String,
1170) -> Example {
1171 let rejected_patch = build_output_patch(
1172 &input.cursor_path,
1173 input.cursor_excerpt.as_ref(),
1174 &input.editable_range_in_excerpt,
1175 &output,
1176 );
1177 let mut example = build_example_from_snowflake(
1178 request_id,
1179 device_id,
1180 time,
1181 input,
1182 vec![format!("rejection:{}", reason.to_lowercase())],
1183 Some(RejectionInfo { reason, was_shown }),
1184 );
1185 example.spec.rejected_patch = Some(rejected_patch);
1186 example
1187}
1188
/// Rejection details attached to an example built from a rejected edit prediction.
#[derive(Debug)]
struct RejectionInfo {
    /// The rejection reason as reported by telemetry.
    reason: String,
    /// The `was_shown` flag reported alongside the rejection.
    was_shown: bool,
}
1193
1194fn build_example_from_snowflake(
1195 request_id: String,
1196 device_id: String,
1197 time: String,
1198 input: ZetaPromptInput,
1199 tags: Vec<String>,
1200 rejection: Option<RejectionInfo>,
1201) -> Example {
1202 let events: Vec<CapturedEvent> = input
1203 .events
1204 .iter()
1205 .map(|event| match event.as_ref() {
1206 zeta_prompt::Event::BufferChange {
1207 path,
1208 old_path,
1209 diff,
1210 predicted,
1211 in_open_source_repo,
1212 } => CapturedEvent {
1213 path: path.clone(),
1214 old_path: old_path.clone(),
1215 diff: diff.clone(),
1216 predicted: *predicted,
1217 in_open_source_repo: *in_open_source_repo,
1218 },
1219 })
1220 .collect();
1221
1222 let related_files: Vec<CapturedRelatedFile> = input
1223 .related_files
1224 .iter()
1225 .map(|rf| CapturedRelatedFile {
1226 path: rf.path.clone(),
1227 max_row: rf.max_row,
1228 excerpts: rf
1229 .excerpts
1230 .iter()
1231 .map(|e| CapturedRelatedExcerpt {
1232 row_range: e.row_range.clone(),
1233 text: e.text.to_string(),
1234 })
1235 .collect(),
1236 })
1237 .collect();
1238
1239 let cursor_excerpt = input.cursor_excerpt.as_ref();
1240 let cursor_offset = input.cursor_offset_in_excerpt;
1241
1242 let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1243
1244 let mut edit_history = String::new();
1245 for event in &input.events {
1246 zeta_prompt::write_event(&mut edit_history, event);
1247 edit_history.push('\n');
1248 }
1249
1250 let (rejection_reason, was_shown) = match &rejection {
1251 Some(r) => (r.reason.clone(), r.was_shown),
1252 None => (String::new(), false),
1253 };
1254
1255 let spec = ExampleSpec {
1256 name: request_id.clone(),
1257 repository_url: String::new(),
1258 revision: String::new(),
1259 tags,
1260 reasoning: None,
1261 uncommitted_diff: String::new(),
1262 cursor_path: input.cursor_path.clone(),
1263 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1264 edit_history,
1265 expected_patches: Vec::new(),
1266 rejected_patch: None,
1267 captured_prompt_input: Some(CapturedPromptInput {
1268 cursor_file_content: cursor_excerpt.to_string(),
1269 cursor_offset,
1270 cursor_row,
1271 cursor_column,
1272 excerpt_start_row: None,
1273 events,
1274 related_files,
1275 }),
1276 telemetry: Some(TelemetrySource {
1277 request_id,
1278 device_id,
1279 time,
1280 rejection_reason,
1281 was_shown,
1282 }),
1283 human_feedback: Vec::new(),
1284 rating: None,
1285 };
1286
1287 Example {
1288 spec,
1289 prompt_inputs: None,
1290 prompt: None,
1291 predictions: Vec::new(),
1292 score: Vec::new(),
1293 qa: Vec::new(),
1294 state: None,
1295 }
1296}
1297
/// Returns the zero-based `(row, column)` of byte `offset` within `text`.
///
/// The column is a byte count from the start of the offset's line. `offset` comes from
/// captured telemetry, so it is first clamped to `text.len()` and then moved back to the
/// nearest char boundary. The result therefore always names a valid position in `text`.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    let mut offset = offset.min(text.len());
    while !text.is_char_boundary(offset) {
        offset -= 1;
    }

    // `\n` is a single byte that never appears inside a multi-byte UTF-8 sequence, so a
    // byte scan of the prefix finds exactly the line breaks before `offset`.
    let prefix = &text.as_bytes()[..offset];
    let row = prefix.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = prefix
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline| newline + 1);
    let column = (offset - line_start) as u32;
    (row, column)
}
1313
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte `cursor_offset`.
///
/// The offset comes from captured telemetry and is not trusted. It is clamped to the end
/// of the excerpt and moved back to the nearest char boundary, so a bad offset cannot
/// panic on the string split.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
1319
1320fn build_output_patch(
1321 cursor_path: &std::path::Path,
1322 cursor_excerpt: &str,
1323 editable_range: &std::ops::Range<usize>,
1324 model_output: &str,
1325) -> String {
1326 let old_text = &cursor_excerpt[editable_range.clone()];
1327
1328 let editable_start_row = cursor_excerpt[..editable_range.start]
1329 .chars()
1330 .filter(|&c| c == '\n')
1331 .count() as u32;
1332
1333 let diff_body = language::unified_diff_with_offsets(
1334 old_text,
1335 model_output,
1336 editable_start_row,
1337 editable_start_row,
1338 );
1339
1340 let mut patch = String::new();
1341 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1342 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1343 patch.push_str(&diff_body);
1344 patch
1345}
1346
1347fn get_column_indices(
1348 meta: &Option<SnowflakeResultSetMetaData>,
1349 names: &[&str],
1350) -> std::collections::HashMap<String, usize> {
1351 let mut indices = std::collections::HashMap::new();
1352 if let Some(meta) = meta {
1353 for (index, col) in meta.row_type.iter().enumerate() {
1354 for &name in names {
1355 if col.name.eq_ignore_ascii_case(name) {
1356 indices.insert(name.to_string(), index);
1357 }
1358 }
1359 }
1360 }
1361 indices
1362}