1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11use telemetry_events::EditPredictionRating;
12
13use zeta_prompt::ZetaPromptInput;
14
15use crate::example::Example;
16use crate::progress::{InfoStyle, Progress, Step};
17use crate::sync_deployments::EDIT_PREDICTION_DEPLOYMENT_EVENT;
18use edit_prediction::example_spec::{
19 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
20 TelemetrySource,
21};
22use std::fmt::Write as _;
23
/// Response `code` that the `*_from_response` helpers accept; any other code
/// present in a response makes them return an error.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` that `run_sql_with_polling` interprets as "statement is
/// still executing", which triggers polling.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` bound by the captured-examples query.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` bound for request events by the rejected, requested and rated queries.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` bound for rejection events by the rejected-examples query.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// `event_type` bound for rating events by the rated-examples query.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";

/// Value sent as the `timeout` field of every statement request.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each poll of a statement that is still in progress.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Upper bound on polls per statement; together with `POLL_INTERVAL` this
/// allows 240 seconds of waiting before `run_sql_with_polling` gives up.
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
34
/// Recognizes a `captured-after:{timestamp}` token.
///
/// Returns the text following the prefix (possibly empty), or `None` when
/// `input` does not start with `captured-after:`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
39
/// Recognizes a `rejected-after:{timestamp}` token.
///
/// Returns the text following the prefix (possibly empty), or `None` when
/// `input` does not start with `rejected-after:`.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
44
/// Recognizes a `requested-after:{timestamp}` token.
///
/// Returns the text following the prefix (possibly empty), or `None` when
/// `input` does not start with `requested-after:`.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
49
50/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
51/// or `rated-negative-after:{timestamp}`.
52/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
53pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
54 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
55 Some((timestamp, Some(EditPredictionRating::Positive)))
56 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
57 Some((timestamp, Some(EditPredictionRating::Negative)))
58 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
59 Some((timestamp, None))
60 } else {
61 None
62 }
63}
64
65pub async fn fetch_captured_examples_after(
66 http_client: Arc<dyn HttpClient>,
67 after_timestamps: &[String],
68 max_rows_per_timestamp: usize,
69 background_executor: BackgroundExecutor,
70) -> Result<Vec<Example>> {
71 if after_timestamps.is_empty() {
72 return Ok(Vec::new());
73 }
74
75 let progress = Progress::global();
76
77 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
78 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
79 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
80 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
81 )?;
82 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
83
84 let mut all_examples = Vec::new();
85
86 for after_date in after_timestamps.iter() {
87 let step_progress_name = format!(">{after_date}");
88 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
89 step_progress.set_substatus("querying");
90
91 let statement = indoc! {r#"
92 SELECT
93 event_properties:example AS example
94 FROM events
95 WHERE event_type = ?
96 AND time > TRY_TO_TIMESTAMP_NTZ(?)
97 ORDER BY time ASC
98 LIMIT ?
99 "#};
100
101 let request = json!({
102 "statement": statement,
103 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
104 "database": "EVENTS",
105 "schema": "PUBLIC",
106 "warehouse": "DBT",
107 "role": role,
108 "bindings": {
109 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
110 "2": { "type": "TEXT", "value": after_date },
111 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
112 }
113 });
114
115 let response = run_sql_with_polling(
116 http_client.clone(),
117 &base_url,
118 &token,
119 &request,
120 &step_progress,
121 background_executor.clone(),
122 )
123 .await?;
124
125 let total_rows = response
126 .result_set_meta_data
127 .as_ref()
128 .and_then(|m| m.num_rows)
129 .unwrap_or(response.data.len() as i64);
130
131 let num_partitions = response
132 .result_set_meta_data
133 .as_ref()
134 .map(|m| m.partition_info.len())
135 .unwrap_or(1)
136 .max(1);
137
138 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
139 step_progress.set_substatus("parsing");
140
141 let example_index = response
142 .result_set_meta_data
143 .as_ref()
144 .and_then(|m| {
145 m.row_type.iter().enumerate().find_map(|(index, col)| {
146 if col.name.eq_ignore_ascii_case("example") {
147 Some(index)
148 } else {
149 None
150 }
151 })
152 })
153 .unwrap_or(0);
154
155 all_examples.extend(examples_from_response(&response, example_index)?);
156
157 if num_partitions > 1 {
158 let statement_handle = response
159 .statement_handle
160 .as_ref()
161 .context("response has multiple partitions but no statementHandle")?;
162
163 for partition in 1..num_partitions {
164 step_progress.set_substatus(format!(
165 "fetching partition {}/{}",
166 partition + 1,
167 num_partitions
168 ));
169
170 let partition_response = fetch_partition(
171 http_client.clone(),
172 &base_url,
173 &token,
174 statement_handle,
175 partition,
176 )
177 .await?;
178
179 all_examples.extend(examples_from_response(&partition_response, example_index)?);
180 }
181 }
182
183 step_progress.set_substatus("done");
184 }
185
186 Ok(all_examples)
187}
188
/// JSON body returned by the Snowflake SQL API, as deserialized by `run_sql`
/// and `fetch_partition`.
///
/// JSON keys are camelCase; every field falls back to its default when the key
/// is absent from the body.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows of this page; each row is a list of JSON cells that the
    /// parsers index by column position.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Result-set metadata (column names, row count, partitions), when present.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Message accompanying `code`; included in error reports.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll the statement and to fetch further partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
203
/// Metadata describing a statement's result set.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeResultSetMetaData {
    /// Column descriptors; a column's position here is used as its index into
    /// each data row.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported by the server, shown in progress output when present.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
214
/// Placeholder for one `partitionInfo` entry. No fields are read; entries are
/// only counted to determine how many partitions to fetch.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
218
/// Descriptor of one result column; only its name is deserialized.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name as reported by the server; empty when the key is absent.
    #[serde(default)]
    name: String,
}
224
225fn examples_from_response(
226 response: &SnowflakeStatementResponse,
227 example_index: usize,
228) -> Result<impl Iterator<Item = Example> + '_> {
229 if let Some(code) = &response.code {
230 if code != SNOWFLAKE_SUCCESS_CODE {
231 anyhow::bail!(
232 "snowflake sql api returned error code={code} message={}",
233 response.message.as_deref().unwrap_or("<no message>")
234 );
235 }
236 }
237
238 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
239 let Some(example_value) = data_row.get(example_index) else {
240 return None;
241 };
242 if example_value.is_null() {
243 return None;
244 }
245
246 let parse_result = match example_value {
247 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
248 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
249 };
250
251 match parse_result {
252 Ok(spec) => Some(Example {
253 spec,
254 prompt_inputs: None,
255 prompt: None,
256 predictions: Vec::new(),
257 score: Vec::new(),
258 qa: Vec::new(),
259 state: None,
260 }),
261 Err(error) => {
262 let raw_json = serde_json::to_string_pretty(example_value)
263 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
264 log::error!(
265 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
266 );
267 None
268 }
269 }
270 });
271
272 Ok(iter)
273}
274
275async fn run_sql_with_polling(
276 http_client: Arc<dyn HttpClient>,
277 base_url: &str,
278 token: &str,
279 request: &serde_json::Value,
280 step_progress: &crate::progress::StepProgress,
281 background_executor: BackgroundExecutor,
282) -> Result<SnowflakeStatementResponse> {
283 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
284
285 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
286 let statement_handle = response
287 .statement_handle
288 .as_ref()
289 .context("async query response missing statementHandle")?
290 .clone();
291
292 for attempt in 1..=MAX_POLL_ATTEMPTS {
293 step_progress.set_substatus(format!("polling ({attempt})"));
294
295 background_executor.timer(POLL_INTERVAL).await;
296
297 response =
298 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
299
300 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
301 break;
302 }
303 }
304
305 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
306 anyhow::bail!(
307 "query still running after {} poll attempts ({} seconds)",
308 MAX_POLL_ATTEMPTS,
309 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
310 );
311 }
312 }
313
314 Ok(response)
315}
316
317pub(crate) async fn fetch_partition(
318 http_client: Arc<dyn HttpClient>,
319 base_url: &str,
320 token: &str,
321 statement_handle: &str,
322 partition: usize,
323) -> Result<SnowflakeStatementResponse> {
324 let url = format!(
325 "{}/api/v2/statements/{}?partition={}",
326 base_url.trim_end_matches('/'),
327 statement_handle,
328 partition
329 );
330
331 let http_request = Request::builder()
332 .method(Method::GET)
333 .uri(url.as_str())
334 .header("Authorization", format!("Bearer {token}"))
335 .header(
336 "X-Snowflake-Authorization-Token-Type",
337 "PROGRAMMATIC_ACCESS_TOKEN",
338 )
339 .header("Accept", "application/json")
340 .header("Accept-Encoding", "gzip")
341 .header("User-Agent", "edit_prediction_cli")
342 .body(AsyncBody::empty())?;
343
344 let response = http_client
345 .send(http_request)
346 .await
347 .context("failed to send partition request to Snowflake SQL API")?;
348
349 let status = response.status();
350 let content_encoding = response
351 .headers()
352 .get("content-encoding")
353 .and_then(|v| v.to_str().ok())
354 .map(|s| s.to_lowercase());
355
356 let body_bytes = {
357 use futures::AsyncReadExt as _;
358
359 let mut body = response.into_body();
360 let mut bytes = Vec::new();
361 body.read_to_end(&mut bytes)
362 .await
363 .context("failed to read Snowflake SQL API partition response body")?;
364 bytes
365 };
366
367 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
368 let mut decoder = GzDecoder::new(&body_bytes[..]);
369 let mut decompressed = Vec::new();
370 decoder
371 .read_to_end(&mut decompressed)
372 .context("failed to decompress gzip response")?;
373 decompressed
374 } else {
375 body_bytes
376 };
377
378 if !status.is_success() && status.as_u16() != 202 {
379 let body_text = String::from_utf8_lossy(&body_bytes);
380 anyhow::bail!(
381 "snowflake sql api partition request http {}: {}",
382 status.as_u16(),
383 body_text
384 );
385 }
386
387 if body_bytes.is_empty() {
388 anyhow::bail!(
389 "snowflake sql api partition {} returned empty response body (http {})",
390 partition,
391 status.as_u16()
392 );
393 }
394
395 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
396 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
397 format!(
398 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
399 partition,
400 status.as_u16(),
401 body_preview
402 )
403 })
404}
405
406pub(crate) async fn run_sql(
407 http_client: Arc<dyn HttpClient>,
408 base_url: &str,
409 token: &str,
410 request: &serde_json::Value,
411) -> Result<SnowflakeStatementResponse> {
412 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
413
414 let request_body =
415 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
416
417 let http_request = Request::builder()
418 .method(Method::POST)
419 .uri(url.as_str())
420 .header("Authorization", format!("Bearer {token}"))
421 .header(
422 "X-Snowflake-Authorization-Token-Type",
423 "PROGRAMMATIC_ACCESS_TOKEN",
424 )
425 .header("Content-Type", "application/json")
426 .header("Accept", "application/json")
427 .header("User-Agent", "edit_prediction_cli")
428 .body(AsyncBody::from(request_body.clone()))?;
429
430 let response = http_client
431 .send(http_request)
432 .await
433 .context("failed to send request to Snowflake SQL API")?;
434
435 let status = response.status();
436 let body_bytes = {
437 use futures::AsyncReadExt as _;
438
439 let mut body = response.into_body();
440 let mut bytes = Vec::new();
441 body.read_to_end(&mut bytes)
442 .await
443 .context("failed to read Snowflake SQL API response body")?;
444 bytes
445 };
446
447 if !status.is_success() && status.as_u16() != 202 {
448 let body_text = String::from_utf8_lossy(&body_bytes);
449 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
450 }
451
452 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
453 .context("failed to parse Snowflake SQL API response JSON")
454}
455
456pub async fn fetch_rejected_examples_after(
457 http_client: Arc<dyn HttpClient>,
458 after_timestamps: &[String],
459 max_rows_per_timestamp: usize,
460 background_executor: BackgroundExecutor,
461) -> Result<Vec<Example>> {
462 if after_timestamps.is_empty() {
463 return Ok(Vec::new());
464 }
465
466 let progress = Progress::global();
467
468 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
469 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
470 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
471 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
472 )?;
473 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
474
475 let mut all_examples = Vec::new();
476
477 for after_date in after_timestamps.iter() {
478 let step_progress_name = format!("rejected>{after_date}");
479 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
480 step_progress.set_substatus("querying");
481
482 // Join rejected events with their corresponding request events to get the full context.
483 // We filter for V3 sampling data which contains the structured input we need.
484 // We also filter for predictions that were actually shown to the user (was_shown = true)
485 // to focus on explicit user rejections rather than implicit cancellations.
486 let statement = indoc! {r#"
487 SELECT
488 req.event_properties:request_id::string AS request_id,
489 req.device_id::string AS device_id,
490 req.time::string AS time,
491 req.event_properties:input AS input,
492 req.event_properties:prompt::string AS prompt,
493 req.event_properties:output::string AS output,
494 rej.event_properties:was_shown::boolean AS was_shown,
495 rej.event_properties:reason::string AS reason
496 FROM events req
497 INNER JOIN events rej
498 ON req.event_properties:request_id = rej.event_properties:request_id
499 WHERE req.event_type = ?
500 AND rej.event_type = ?
501 AND req.event_properties:version = 'V3'
502 AND rej.event_properties:was_shown = true
503 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
504 ORDER BY req.time ASC
505 LIMIT ?
506 "#};
507
508 let request = json!({
509 "statement": statement,
510 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
511 "database": "EVENTS",
512 "schema": "PUBLIC",
513 "warehouse": "DBT",
514 "role": role,
515 "bindings": {
516 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
517 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
518 "3": { "type": "TEXT", "value": after_date },
519 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
520 }
521 });
522
523 let response = run_sql_with_polling(
524 http_client.clone(),
525 &base_url,
526 &token,
527 &request,
528 &step_progress,
529 background_executor.clone(),
530 )
531 .await?;
532
533 let total_rows = response
534 .result_set_meta_data
535 .as_ref()
536 .and_then(|m| m.num_rows)
537 .unwrap_or(response.data.len() as i64);
538
539 let num_partitions = response
540 .result_set_meta_data
541 .as_ref()
542 .map(|m| m.partition_info.len())
543 .unwrap_or(1)
544 .max(1);
545
546 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
547 step_progress.set_substatus("parsing");
548
549 let column_indices = get_column_indices(
550 &response.result_set_meta_data,
551 &[
552 "request_id",
553 "device_id",
554 "time",
555 "input",
556 "prompt",
557 "output",
558 "was_shown",
559 "reason",
560 ],
561 );
562
563 all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
564
565 if num_partitions > 1 {
566 let statement_handle = response
567 .statement_handle
568 .as_ref()
569 .context("response has multiple partitions but no statementHandle")?;
570
571 for partition in 1..num_partitions {
572 step_progress.set_substatus(format!(
573 "fetching partition {}/{}",
574 partition + 1,
575 num_partitions
576 ));
577
578 let partition_response = fetch_partition(
579 http_client.clone(),
580 &base_url,
581 &token,
582 statement_handle,
583 partition,
584 )
585 .await?;
586
587 all_examples.extend(rejected_examples_from_response(
588 &partition_response,
589 &column_indices,
590 )?);
591 }
592 }
593
594 step_progress.set_substatus("done");
595 }
596
597 Ok(all_examples)
598}
599
600pub async fn fetch_requested_examples_after(
601 http_client: Arc<dyn HttpClient>,
602 after_timestamps: &[String],
603 max_rows_per_timestamp: usize,
604 background_executor: BackgroundExecutor,
605) -> Result<Vec<Example>> {
606 if after_timestamps.is_empty() {
607 return Ok(Vec::new());
608 }
609
610 let progress = Progress::global();
611
612 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
613 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
614 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
615 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
616 )?;
617 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
618
619 let mut all_examples = Vec::new();
620
621 for after_date in after_timestamps.iter() {
622 let step_progress_name = format!("requested>{after_date}");
623 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
624 step_progress.set_substatus("querying");
625
626 let statement = indoc! {r#"
627 SELECT
628 req.event_properties:request_id::string AS request_id,
629 req.device_id::string AS device_id,
630 req.time::string AS time,
631 req.event_properties:input AS input
632 FROM events req
633 WHERE req.event_type = ?
634 AND req.event_properties:version = 'V3'
635 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
636 ORDER BY req.time ASC
637 LIMIT ?
638 "#};
639
640 let request = json!({
641 "statement": statement,
642 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
643 "database": "EVENTS",
644 "schema": "PUBLIC",
645 "warehouse": "DBT",
646 "role": role,
647 "bindings": {
648 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
649 "2": { "type": "TEXT", "value": after_date },
650 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
651 }
652 });
653
654 let response = run_sql_with_polling(
655 http_client.clone(),
656 &base_url,
657 &token,
658 &request,
659 &step_progress,
660 background_executor.clone(),
661 )
662 .await?;
663
664 let total_rows = response
665 .result_set_meta_data
666 .as_ref()
667 .and_then(|m| m.num_rows)
668 .unwrap_or(response.data.len() as i64);
669
670 let num_partitions = response
671 .result_set_meta_data
672 .as_ref()
673 .map(|m| m.partition_info.len())
674 .unwrap_or(1)
675 .max(1);
676
677 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
678 step_progress.set_substatus("parsing");
679
680 let column_indices = get_column_indices(
681 &response.result_set_meta_data,
682 &["request_id", "device_id", "time", "input"],
683 );
684
685 all_examples.extend(requested_examples_from_response(
686 &response,
687 &column_indices,
688 )?);
689
690 if num_partitions > 1 {
691 let statement_handle = response
692 .statement_handle
693 .as_ref()
694 .context("response has multiple partitions but no statementHandle")?;
695
696 for partition in 1..num_partitions {
697 step_progress.set_substatus(format!(
698 "fetching partition {}/{}",
699 partition + 1,
700 num_partitions
701 ));
702
703 let partition_response = fetch_partition(
704 http_client.clone(),
705 &base_url,
706 &token,
707 statement_handle,
708 partition,
709 )
710 .await?;
711
712 all_examples.extend(requested_examples_from_response(
713 &partition_response,
714 &column_indices,
715 )?);
716 }
717 }
718
719 step_progress.set_substatus("done");
720 }
721
722 Ok(all_examples)
723}
724
725pub async fn fetch_rated_examples_after(
726 http_client: Arc<dyn HttpClient>,
727 inputs: &[(String, Option<EditPredictionRating>)],
728 max_rows_per_timestamp: usize,
729 background_executor: BackgroundExecutor,
730) -> Result<Vec<Example>> {
731 if inputs.is_empty() {
732 return Ok(Vec::new());
733 }
734
735 let progress = Progress::global();
736
737 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
738 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
739 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
740 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
741 )?;
742 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
743
744 let mut all_examples = Vec::new();
745
746 for (after_date, rating_filter) in inputs.iter() {
747 let filter_label = match rating_filter {
748 None => "",
749 Some(EditPredictionRating::Positive) => ":positive",
750 Some(EditPredictionRating::Negative) => ":negative",
751 };
752 let step_progress_name = format!("rated{filter_label}>{after_date}");
753 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
754 step_progress.set_substatus("querying");
755
756 let rating_value = rating_filter.as_ref().map(|r| match r {
757 EditPredictionRating::Positive => "Positive",
758 EditPredictionRating::Negative => "Negative",
759 });
760
761 let statement = indoc! {r#"
762 SELECT
763 rated.event_properties:request_id::string AS request_id,
764 rated.event_properties:inputs AS inputs,
765 rated.event_properties:output::string AS output,
766 rated.event_properties:rating::string AS rating,
767 rated.event_properties:feedback::string AS feedback,
768 rated.device_id::string AS device_id,
769 rated.time::string AS time,
770 deploy.event_properties:experiment_name::string AS experiment_name,
771 deploy.event_properties:environment::string AS environment
772 FROM events rated
773 LEFT JOIN events req
774 ON rated.event_properties:request_id::string = req.event_properties:request_id::string
775 AND req.event_type = ?
776 LEFT JOIN events deploy
777 ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
778 AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
779 AND deploy.event_type = ?
780 WHERE rated.event_type = ?
781 AND (? IS NULL OR rated.event_properties:rating::string = ?)
782 AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
783 AND rated.event_properties:inputs IS NOT NULL
784 AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
785 AND rated.event_properties:output IS NOT NULL
786 ORDER BY rated.time ASC
787 LIMIT ?
788 "#};
789
790 let bindings = json!({
791 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
792 "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
793 "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
794 "4": { "type": "TEXT", "value": rating_value },
795 "5": { "type": "TEXT", "value": rating_value },
796 "6": { "type": "TEXT", "value": after_date },
797 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
798 });
799
800 let request = json!({
801 "statement": statement,
802 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
803 "database": "EVENTS",
804 "schema": "PUBLIC",
805 "warehouse": "DBT",
806 "role": role,
807 "bindings": bindings
808 });
809
810 let response = run_sql_with_polling(
811 http_client.clone(),
812 &base_url,
813 &token,
814 &request,
815 &step_progress,
816 background_executor.clone(),
817 )
818 .await?;
819
820 let total_rows = response
821 .result_set_meta_data
822 .as_ref()
823 .and_then(|m| m.num_rows)
824 .unwrap_or(response.data.len() as i64);
825
826 let num_partitions = response
827 .result_set_meta_data
828 .as_ref()
829 .map(|m| m.partition_info.len())
830 .unwrap_or(1)
831 .max(1);
832
833 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
834 step_progress.set_substatus("parsing");
835
836 let column_indices = get_column_indices(
837 &response.result_set_meta_data,
838 &[
839 "request_id",
840 "inputs",
841 "output",
842 "rating",
843 "feedback",
844 "device_id",
845 "time",
846 "experiment_name",
847 "environment",
848 ],
849 );
850
851 all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
852
853 if num_partitions > 1 {
854 let statement_handle = response
855 .statement_handle
856 .as_ref()
857 .context("response has multiple partitions but no statementHandle")?;
858
859 for partition in 1..num_partitions {
860 step_progress.set_substatus(format!(
861 "fetching partition {}/{}",
862 partition + 1,
863 num_partitions
864 ));
865
866 let partition_response = fetch_partition(
867 http_client.clone(),
868 &base_url,
869 &token,
870 statement_handle,
871 partition,
872 )
873 .await?;
874
875 all_examples.extend(rated_examples_from_response(
876 &partition_response,
877 &column_indices,
878 )?);
879 }
880 }
881
882 step_progress.set_substatus("done");
883 }
884
885 Ok(all_examples)
886}
887
888fn rated_examples_from_response<'a>(
889 response: &'a SnowflakeStatementResponse,
890 column_indices: &'a std::collections::HashMap<String, usize>,
891) -> Result<impl Iterator<Item = Example> + 'a> {
892 if let Some(code) = &response.code {
893 if code != SNOWFLAKE_SUCCESS_CODE {
894 anyhow::bail!(
895 "snowflake sql api returned error code={code} message={}",
896 response.message.as_deref().unwrap_or("<no message>")
897 );
898 }
899 }
900
901 let iter = response
902 .data
903 .iter()
904 .enumerate()
905 .filter_map(move |(row_index, data_row)| {
906 let get_string = |name: &str| -> Option<String> {
907 let index = column_indices.get(name).copied()?;
908 match data_row.get(index)? {
909 JsonValue::String(s) => Some(s.clone()),
910 JsonValue::Null => None,
911 other => Some(other.to_string()),
912 }
913 };
914
915 let get_json = |name: &str| -> Option<JsonValue> {
916 let index = column_indices.get(name).copied()?;
917 let value = data_row.get(index)?;
918 if value.is_null() {
919 return None;
920 }
921 match value {
922 JsonValue::String(s) => serde_json::from_str(s).ok(),
923 other => Some(other.clone()),
924 }
925 };
926
927 let request_id = get_string("request_id");
928 let inputs_json = get_json("inputs");
929 let inputs: Option<ZetaPromptInput> = match &inputs_json {
930 Some(v) => match serde_json::from_value(v.clone()) {
931 Ok(parsed) => Some(parsed),
932 Err(e) => {
933 log::warn!(
934 "skipping row {row_index}: failed to parse inputs - {e}",
935 );
936 return None;
937 }
938 },
939 None => None,
940 };
941 let output = get_string("output");
942 let rating = get_string("rating");
943 let feedback = get_string("feedback").unwrap_or_default();
944 let device_id = get_string("device_id");
945 let time = get_string("time");
946 let experiment_name = get_string("experiment_name");
947 let environment = get_string("environment");
948
949 match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
950 (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
951 Some(build_rated_example(
952 request_id,
953 device_id,
954 time,
955 inputs,
956 output,
957 rating,
958 feedback,
959 experiment_name,
960 environment,
961 ))
962 }
963 _ => {
964 log::warn!(
965 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
966 inputs_json.is_some(),
967 output.is_some(),
968 rating.is_some(),
969 device_id.is_some(),
970 time.is_some(),
971 );
972 None
973 }
974 }
975 });
976
977 Ok(iter)
978}
979
980fn build_rated_example(
981 request_id: Option<String>,
982 device_id: String,
983 time: String,
984 input: ZetaPromptInput,
985 output: String,
986 rating: String,
987 feedback: String,
988 experiment_name: Option<String>,
989 environment: Option<String>,
990) -> Example {
991 let parsed_rating = if rating == "Positive" {
992 EditPredictionRating::Positive
993 } else {
994 EditPredictionRating::Negative
995 };
996 let is_positive = parsed_rating == EditPredictionRating::Positive;
997 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
998
999 let mut tags = Vec::with_capacity(3);
1000 tags.push(if is_positive {
1001 "rated:positive".to_string()
1002 } else {
1003 "rated:negative".to_string()
1004 });
1005 if let Some(experiment) = experiment_name {
1006 tags.push(format!("experiment:{experiment}"));
1007 }
1008 if let Some(env) = environment {
1009 tags.push(format!("environment:{env}"));
1010 }
1011
1012 let mut example = build_example_from_snowflake(request_id, device_id, time, input, tags, None);
1013
1014 example.spec.rating = Some(parsed_rating);
1015
1016 if !feedback.is_empty() {
1017 example
1018 .spec
1019 .human_feedback
1020 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1021 }
1022
1023 if is_positive {
1024 example.spec.expected_patches = vec![output];
1025 } else {
1026 example.spec.rejected_patch = Some(output);
1027 }
1028
1029 example
1030}
1031
1032fn requested_examples_from_response<'a>(
1033 response: &'a SnowflakeStatementResponse,
1034 column_indices: &'a std::collections::HashMap<String, usize>,
1035) -> Result<impl Iterator<Item = Example> + 'a> {
1036 if let Some(code) = &response.code {
1037 if code != SNOWFLAKE_SUCCESS_CODE {
1038 anyhow::bail!(
1039 "snowflake sql api returned error code={code} message={}",
1040 response.message.as_deref().unwrap_or("<no message>")
1041 );
1042 }
1043 }
1044
1045 let iter = response
1046 .data
1047 .iter()
1048 .enumerate()
1049 .filter_map(move |(row_index, data_row)| {
1050 let get_string = |name: &str| -> Option<String> {
1051 let index = column_indices.get(name).copied()?;
1052 match data_row.get(index)? {
1053 JsonValue::String(s) => Some(s.clone()),
1054 JsonValue::Null => None,
1055 other => Some(other.to_string()),
1056 }
1057 };
1058
1059 let get_json = |name: &str| -> Option<JsonValue> {
1060 let index = column_indices.get(name).copied()?;
1061 let value = data_row.get(index)?;
1062 if value.is_null() {
1063 return None;
1064 }
1065 match value {
1066 JsonValue::String(s) => serde_json::from_str(s).ok(),
1067 other => Some(other.clone()),
1068 }
1069 };
1070
1071 let request_id_str = get_string("request_id");
1072 let device_id = get_string("device_id");
1073 let time = get_string("time");
1074 let input_json = get_json("input");
1075 let input: Option<ZetaPromptInput> =
1076 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1077
1078 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1079 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1080 Some(build_example_from_snowflake(
1081 request_id,
1082 device_id,
1083 time,
1084 input,
1085 vec!["requested".to_string()],
1086 None,
1087 ))
1088 }
1089 _ => {
1090 log::warn!(
1091 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1092 request_id_str.is_some(),
1093 device_id.is_some(),
1094 time.is_some(),
1095 input_json.is_some(),
1096 );
1097 None
1098 }
1099 }
1100 });
1101
1102 Ok(iter)
1103}
1104
1105fn rejected_examples_from_response<'a>(
1106 response: &'a SnowflakeStatementResponse,
1107 column_indices: &'a std::collections::HashMap<String, usize>,
1108) -> Result<impl Iterator<Item = Example> + 'a> {
1109 if let Some(code) = &response.code {
1110 if code != SNOWFLAKE_SUCCESS_CODE {
1111 anyhow::bail!(
1112 "snowflake sql api returned error code={code} message={}",
1113 response.message.as_deref().unwrap_or("<no message>")
1114 );
1115 }
1116 }
1117
1118 let iter = response
1119 .data
1120 .iter()
1121 .enumerate()
1122 .filter_map(move |(row_index, data_row)| {
1123 let get_string = |name: &str| -> Option<String> {
1124 let index = column_indices.get(name).copied()?;
1125 match data_row.get(index)? {
1126 JsonValue::String(s) => Some(s.clone()),
1127 JsonValue::Null => None,
1128 other => Some(other.to_string()),
1129 }
1130 };
1131
1132 let get_json = |name: &str| -> Option<JsonValue> {
1133 let index = column_indices.get(name).copied()?;
1134 let value = data_row.get(index)?;
1135 if value.is_null() {
1136 return None;
1137 }
1138 match value {
1139 JsonValue::String(s) => serde_json::from_str(s).ok(),
1140 other => Some(other.clone()),
1141 }
1142 };
1143
1144 let get_bool = |name: &str| -> Option<bool> {
1145 let index = column_indices.get(name).copied()?;
1146 match data_row.get(index)? {
1147 JsonValue::Bool(b) => Some(*b),
1148 JsonValue::String(s) => s.parse().ok(),
1149 _ => None,
1150 }
1151 };
1152
1153 let request_id_str = get_string("request_id");
1154 let device_id = get_string("device_id");
1155 let time = get_string("time");
1156 let input_json = get_json("input");
1157 let input: Option<ZetaPromptInput> =
1158 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1159 let output = get_string("output");
1160 let was_shown = get_bool("was_shown");
1161 let reason = get_string("reason");
1162
1163 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1164 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1165 Some(build_rejected_example(
1166 request_id,
1167 device_id,
1168 time,
1169 input,
1170 output,
1171 was_shown,
1172 reason,
1173 ))
1174 }
1175 _ => {
1176 log::warn!(
1177 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1178 request_id_str.is_some(),
1179 device_id.is_some(),
1180 time.is_some(),
1181 input_json.is_some(),
1182 output.is_some(),
1183 was_shown.is_some(),
1184 reason.is_some()
1185 );
1186 None
1187 }
1188 }
1189 });
1190
1191 Ok(iter)
1192}
1193
1194fn build_rejected_example(
1195 request_id: String,
1196 device_id: String,
1197 time: String,
1198 input: ZetaPromptInput,
1199 output: String,
1200 was_shown: bool,
1201 reason: String,
1202) -> Example {
1203 let rejected_patch = build_output_patch(
1204 &input.cursor_path,
1205 input.cursor_excerpt.as_ref(),
1206 &input.editable_range_in_excerpt,
1207 &output,
1208 );
1209 let mut example = build_example_from_snowflake(
1210 request_id,
1211 device_id,
1212 time,
1213 input,
1214 vec![format!("rejection:{}", reason.to_lowercase())],
1215 Some(RejectionInfo { reason, was_shown }),
1216 );
1217 example.spec.rejected_patch = Some(rejected_patch);
1218 example
1219}
1220
/// Rejection details attached to an example built from a rejected prediction.
///
/// `build_example_from_snowflake` copies both fields into the example's
/// `TelemetrySource` (`rejection_reason` and `was_shown`).
struct RejectionInfo {
    /// Rejection reason as read from the `reason` column of the query result.
    reason: String,
    /// Value of the `was_shown` column of the query result.
    was_shown: bool,
}
1225
1226fn build_example_from_snowflake(
1227 request_id: String,
1228 device_id: String,
1229 time: String,
1230 input: ZetaPromptInput,
1231 tags: Vec<String>,
1232 rejection: Option<RejectionInfo>,
1233) -> Example {
1234 let events: Vec<CapturedEvent> = input
1235 .events
1236 .iter()
1237 .map(|event| match event.as_ref() {
1238 zeta_prompt::Event::BufferChange {
1239 path,
1240 old_path,
1241 diff,
1242 predicted,
1243 in_open_source_repo,
1244 } => CapturedEvent {
1245 path: path.clone(),
1246 old_path: old_path.clone(),
1247 diff: diff.clone(),
1248 predicted: *predicted,
1249 in_open_source_repo: *in_open_source_repo,
1250 },
1251 })
1252 .collect();
1253
1254 let related_files: Vec<CapturedRelatedFile> = input
1255 .related_files
1256 .iter()
1257 .map(|rf| CapturedRelatedFile {
1258 path: rf.path.clone(),
1259 max_row: rf.max_row,
1260 excerpts: rf
1261 .excerpts
1262 .iter()
1263 .map(|e| CapturedRelatedExcerpt {
1264 row_range: e.row_range.clone(),
1265 text: e.text.to_string(),
1266 })
1267 .collect(),
1268 })
1269 .collect();
1270
1271 let cursor_excerpt = input.cursor_excerpt.as_ref();
1272 let cursor_offset = input.cursor_offset_in_excerpt;
1273
1274 let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1275
1276 let mut edit_history = String::new();
1277 for event in &input.events {
1278 zeta_prompt::write_event(&mut edit_history, event);
1279 edit_history.push('\n');
1280 }
1281
1282 let (rejection_reason, was_shown) = match &rejection {
1283 Some(r) => (r.reason.clone(), r.was_shown),
1284 None => (String::new(), false),
1285 };
1286
1287 let spec = ExampleSpec {
1288 name: request_id.clone(),
1289 repository_url: String::new(),
1290 revision: String::new(),
1291 tags,
1292 reasoning: None,
1293 uncommitted_diff: String::new(),
1294 cursor_path: input.cursor_path.clone(),
1295 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1296 edit_history,
1297 expected_patches: Vec::new(),
1298 rejected_patch: None,
1299 captured_prompt_input: Some(CapturedPromptInput {
1300 cursor_file_content: cursor_excerpt.to_string(),
1301 cursor_offset,
1302 cursor_row,
1303 cursor_column,
1304 excerpt_start_row: None,
1305 events,
1306 related_files,
1307 }),
1308 telemetry: Some(TelemetrySource {
1309 request_id,
1310 device_id,
1311 time,
1312 rejection_reason,
1313 was_shown,
1314 }),
1315 human_feedback: Vec::new(),
1316 rating: None,
1317 };
1318
1319 Example {
1320 spec,
1321 prompt_inputs: None,
1322 prompt: None,
1323 predictions: Vec::new(),
1324 score: Vec::new(),
1325 qa: Vec::new(),
1326 state: None,
1327 }
1328}
1329
/// Translates a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// `row` is the number of `'\n'` bytes located strictly before `offset`.
/// `column` is the distance in bytes (not characters) from the start of that
/// line to `offset`. An `offset` beyond the end of `text` is not clamped: the
/// column keeps growing past the end of the last line.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // `'\n'` is a single byte that never appears inside a multi-byte UTF-8
    // sequence, so scanning bytes finds exactly the same newlines as scanning
    // characters.
    let scanned = &text.as_bytes()[..offset.min(text.len())];

    let row = scanned.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = scanned
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline_at| newline_at + 1);

    (row, (offset - line_start) as u32)
}
1345
/// Returns `excerpt` with the literal marker `[CURSOR_POSITION]` inserted at
/// `cursor_offset` (a byte offset).
///
/// The offset is tolerated rather than trusted: a value past the end of the
/// excerpt places the marker at the end, and a value that falls inside a
/// multi-byte character is moved back to the start of that character instead
/// of panicking on an invalid string slice.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1351
1352fn build_output_patch(
1353 cursor_path: &std::path::Path,
1354 cursor_excerpt: &str,
1355 editable_range: &std::ops::Range<usize>,
1356 model_output: &str,
1357) -> String {
1358 let old_text = &cursor_excerpt[editable_range.clone()];
1359
1360 let editable_start_row = cursor_excerpt[..editable_range.start]
1361 .chars()
1362 .filter(|&c| c == '\n')
1363 .count() as u32;
1364
1365 let diff_body = language::unified_diff_with_offsets(
1366 old_text,
1367 model_output,
1368 editable_start_row,
1369 editable_start_row,
1370 );
1371
1372 let mut patch = String::new();
1373 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1374 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1375 patch.push_str(&diff_body);
1376 patch
1377}
1378
1379pub(crate) fn get_column_indices(
1380 meta: &Option<SnowflakeResultSetMetaData>,
1381 names: &[&str],
1382) -> std::collections::HashMap<String, usize> {
1383 let mut indices = std::collections::HashMap::new();
1384 if let Some(meta) = meta {
1385 for (index, col) in meta.row_type.iter().enumerate() {
1386 for &name in names {
1387 if col.name.eq_ignore_ascii_case(name) {
1388 indices.insert(name.to_string(), index);
1389 }
1390 }
1391 }
1392 }
1393 indices
1394}