1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::fmt::Write as _;
9use std::io::Read;
10use std::sync::Arc;
11use std::time::Duration;
12use telemetry_events::EditPredictionRating;
13
14use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
15
16use crate::example::Example;
17use crate::progress::{InfoStyle, Progress, Step};
/// Snowflake `event_type` of deployment events. The rated-examples query joins these on the
/// request's model id / model version id to recover `experiment_name` and `environment`.
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
19use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
20
/// Statement `code` treated as success; any other non-empty code is reported as an error
/// by the response parsers.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Statement `code` meaning the statement is still executing asynchronously; such a
/// response is polled via its `statementHandle` until the code changes.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` of edit prediction request events.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` of edit prediction rejection events (joined to requests by `request_id`).
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// `event_type` of human rating events.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
/// `event_type` of settled events (joined to requests by `request_id`).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
27
/// Lower bound on the Zed version that produced an example.
///
/// Only the minor and patch components are stored, because the version filters in this
/// module's queries compare against `zed_version` strings shaped like
/// `<major>.<minor>.<patch>[+suffix]` and never look at the major component.
/// For instance, `MinCaptureVersion { minor: 224, patch: 1 }` keeps examples whose
/// `zed_version >= 0.224.1`.
#[derive(Copy, Clone, Debug)]
pub struct MinCaptureVersion {
    /// Smallest accepted minor version (the `224` in `0.224.1`).
    pub minor: u32,
    /// Smallest accepted patch version; only consulted when the minor version is equal.
    pub patch: u32,
}
36
/// Value of the statement request's `timeout` field (seconds) for most queries.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Value of the statement request's `timeout` field (seconds) for the settled-examples query.
const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Delay awaited before each poll of a statement that is still running.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Number of polls before giving up on a running statement
/// (120 polls x 2s = 240s of polling in total).
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
41
/// Extracts the timestamp from an input token shaped like `captured-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `captured-after:` prefix.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input.starts_with(PREFIX).then(|| &input[PREFIX.len()..])
}
46
/// Extracts the timestamp from an input token shaped like `rejected-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `rejected-after:` prefix.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    input.starts_with(PREFIX).then(|| &input[PREFIX.len()..])
}
51
/// Extracts the timestamp from an input token shaped like `requested-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `requested-after:` prefix.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    input.starts_with(PREFIX).then(|| &input[PREFIX.len()..])
}
56
/// Extracts the timestamp from an input token shaped like `settled-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `settled-after:` prefix.
pub fn parse_settled_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "settled-after:";
    input.starts_with(PREFIX).then(|| &input[PREFIX.len()..])
}
61
62/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
63/// or `rated-negative-after:{timestamp}`.
64/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
65pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
66 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
67 Some((timestamp, Some(EditPredictionRating::Positive)))
68 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
69 Some((timestamp, Some(EditPredictionRating::Negative)))
70 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
71 Some((timestamp, None))
72 } else {
73 None
74 }
75}
76
/// Body of a Snowflake SQL API statement response, as returned by both statement
/// submission (`run_sql`) and statement/partition fetches (`fetch_partition`).
///
/// Every field is `#[serde(default)]`, so responses that omit some of them still parse.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each cell is kept as raw JSON and interpreted by the row parsers.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Column and partition metadata (`resultSetMetaData`).
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code, compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Message accompanying `code`; included in error reports.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll the statement and to fetch additional partitions
    /// (`statementHandle`).
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
91
92#[derive(Debug, Clone, Deserialize)]
93#[serde(rename_all = "camelCase")]
94pub(crate) struct SnowflakeResultSetMetaData {
95 #[serde(default, rename = "rowType")]
96 row_type: Vec<SnowflakeColumnMeta>,
97 #[serde(default)]
98 num_rows: Option<i64>,
99 #[serde(default)]
100 partition_info: Vec<SnowflakePartitionInfo>,
101}
102
/// One `partitionInfo` entry. Only the number of entries is used, so no fields are
/// deserialized (unknown JSON fields are ignored by serde).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
106
/// One `rowType` entry describing a result column; only the column name is kept.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; empty when the field is absent.
    #[serde(default)]
    name: String,
}
112
113async fn run_sql_with_polling(
114 http_client: Arc<dyn HttpClient>,
115 base_url: &str,
116 token: &str,
117 request: &serde_json::Value,
118 step_progress: &crate::progress::StepProgress,
119 background_executor: BackgroundExecutor,
120) -> Result<SnowflakeStatementResponse> {
121 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
122
123 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
124 let statement_handle = response
125 .statement_handle
126 .as_ref()
127 .context("async query response missing statementHandle")?
128 .clone();
129
130 for attempt in 1..=MAX_POLL_ATTEMPTS {
131 step_progress.set_substatus(format!("polling ({attempt})"));
132
133 background_executor.timer(POLL_INTERVAL).await;
134
135 response =
136 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
137
138 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
139 break;
140 }
141 }
142
143 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
144 anyhow::bail!(
145 "query still running after {} poll attempts ({} seconds)",
146 MAX_POLL_ATTEMPTS,
147 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
148 );
149 }
150 }
151
152 Ok(response)
153}
154
/// Snowflake SQL API connection settings, populated from the `EP_SNOWFLAKE_*` environment
/// variables in `fetch_examples_with_query`.
struct SnowflakeConfig {
    /// Token sent as `Authorization: Bearer <token>` (programmatic access token).
    token: String,
    /// Account base URL, e.g. `https://<account>.snowflakecomputing.com`.
    base_url: String,
    /// Optional role; serialized as JSON `null` in the statement request when unset.
    role: Option<String>,
}
160
161async fn fetch_examples_with_query(
162 http_client: Arc<dyn HttpClient>,
163 step_progress: &crate::progress::StepProgress,
164 background_executor: BackgroundExecutor,
165 statement: &str,
166 bindings: JsonValue,
167 timeout_seconds: u64,
168 required_columns: &[&str],
169 parse_response: for<'a> fn(
170 &'a SnowflakeStatementResponse,
171 &'a std::collections::HashMap<String, usize>,
172 ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
173) -> Result<Vec<Example>> {
174 let snowflake = SnowflakeConfig {
175 token: std::env::var("EP_SNOWFLAKE_API_KEY")
176 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
177 base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context(
178 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
179 )?,
180 role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
181 };
182 let request = json!({
183 "statement": statement,
184 "timeout": timeout_seconds,
185 "database": "EVENTS",
186 "schema": "PUBLIC",
187 "warehouse": "DBT",
188 "role": snowflake.role.as_deref(),
189 "bindings": bindings
190 });
191
192 let response = run_sql_with_polling(
193 http_client.clone(),
194 &snowflake.base_url,
195 &snowflake.token,
196 &request,
197 step_progress,
198 background_executor,
199 )
200 .await?;
201
202 let total_rows = response
203 .result_set_meta_data
204 .as_ref()
205 .and_then(|meta| meta.num_rows)
206 .unwrap_or(response.data.len() as i64);
207 let partition_count = response
208 .result_set_meta_data
209 .as_ref()
210 .map(|meta| meta.partition_info.len())
211 .unwrap_or(1)
212 .max(1);
213
214 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
215 step_progress.set_substatus("parsing");
216
217 let column_indices = get_column_indices(&response.result_set_meta_data, required_columns);
218
219 let mut parsed_examples = Vec::with_capacity(total_rows as usize);
220 parsed_examples.extend(parse_response(&response, &column_indices)?);
221
222 if partition_count > 1 {
223 let statement_handle = response
224 .statement_handle
225 .as_ref()
226 .context("response has multiple partitions but no statementHandle")?;
227
228 for partition in 1..partition_count {
229 step_progress.set_substatus(format!(
230 "fetching partition {}/{}",
231 partition + 1,
232 partition_count
233 ));
234
235 let partition_response = fetch_partition(
236 http_client.clone(),
237 &snowflake.base_url,
238 &snowflake.token,
239 statement_handle,
240 partition,
241 )
242 .await?;
243
244 parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
245 }
246 }
247
248 step_progress.set_substatus("done");
249 Ok(parsed_examples)
250}
251
252pub(crate) async fn fetch_partition(
253 http_client: Arc<dyn HttpClient>,
254 base_url: &str,
255 token: &str,
256 statement_handle: &str,
257 partition: usize,
258) -> Result<SnowflakeStatementResponse> {
259 let url = format!(
260 "{}/api/v2/statements/{}?partition={}",
261 base_url.trim_end_matches('/'),
262 statement_handle,
263 partition
264 );
265
266 let http_request = Request::builder()
267 .method(Method::GET)
268 .uri(url.as_str())
269 .header("Authorization", format!("Bearer {token}"))
270 .header(
271 "X-Snowflake-Authorization-Token-Type",
272 "PROGRAMMATIC_ACCESS_TOKEN",
273 )
274 .header("Accept", "application/json")
275 .header("Accept-Encoding", "gzip")
276 .header("User-Agent", "edit_prediction_cli")
277 .body(AsyncBody::empty())?;
278
279 let response = http_client
280 .send(http_request)
281 .await
282 .context("failed to send partition request to Snowflake SQL API")?;
283
284 let status = response.status();
285 let content_encoding = response
286 .headers()
287 .get("content-encoding")
288 .and_then(|v| v.to_str().ok())
289 .map(|s| s.to_lowercase());
290
291 let body_bytes = {
292 use futures::AsyncReadExt as _;
293
294 let mut body = response.into_body();
295 let mut bytes = Vec::new();
296 body.read_to_end(&mut bytes)
297 .await
298 .context("failed to read Snowflake SQL API partition response body")?;
299 bytes
300 };
301
302 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
303 let mut decoder = GzDecoder::new(&body_bytes[..]);
304 let mut decompressed = Vec::new();
305 decoder
306 .read_to_end(&mut decompressed)
307 .context("failed to decompress gzip response")?;
308 decompressed
309 } else {
310 body_bytes
311 };
312
313 if !status.is_success() && status.as_u16() != 202 {
314 let body_text = String::from_utf8_lossy(&body_bytes);
315 anyhow::bail!(
316 "snowflake sql api partition request http {}: {}",
317 status.as_u16(),
318 body_text
319 );
320 }
321
322 if body_bytes.is_empty() {
323 anyhow::bail!(
324 "snowflake sql api partition {} returned empty response body (http {})",
325 partition,
326 status.as_u16()
327 );
328 }
329
330 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
331 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
332 format!(
333 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
334 partition,
335 status.as_u16(),
336 body_preview
337 )
338 })
339}
340
341pub(crate) async fn run_sql(
342 http_client: Arc<dyn HttpClient>,
343 base_url: &str,
344 token: &str,
345 request: &serde_json::Value,
346) -> Result<SnowflakeStatementResponse> {
347 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
348
349 let request_body =
350 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
351
352 let http_request = Request::builder()
353 .method(Method::POST)
354 .uri(url.as_str())
355 .header("Authorization", format!("Bearer {token}"))
356 .header(
357 "X-Snowflake-Authorization-Token-Type",
358 "PROGRAMMATIC_ACCESS_TOKEN",
359 )
360 .header("Content-Type", "application/json")
361 .header("Accept", "application/json")
362 .header("User-Agent", "edit_prediction_cli")
363 .body(AsyncBody::from(request_body.clone()))?;
364
365 let response = http_client
366 .send(http_request)
367 .await
368 .context("failed to send request to Snowflake SQL API")?;
369
370 let status = response.status();
371 let body_bytes = {
372 use futures::AsyncReadExt as _;
373
374 let mut body = response.into_body();
375 let mut bytes = Vec::new();
376 body.read_to_end(&mut bytes)
377 .await
378 .context("failed to read Snowflake SQL API response body")?;
379 bytes
380 };
381
382 if !status.is_success() && status.as_u16() != 202 {
383 let body_text = String::from_utf8_lossy(&body_bytes);
384 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
385 }
386
387 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
388 .context("failed to parse Snowflake SQL API response JSON")
389}
390
391pub async fn fetch_rejected_examples_after(
392 http_client: Arc<dyn HttpClient>,
393 after_timestamps: &[String],
394 max_rows_per_timestamp: usize,
395 offset: usize,
396 background_executor: BackgroundExecutor,
397 min_capture_version: Option<MinCaptureVersion>,
398) -> Result<Vec<Example>> {
399 if after_timestamps.is_empty() {
400 return Ok(Vec::new());
401 }
402
403 let progress = Progress::global();
404
405 let mut all_examples = Vec::new();
406
407 for after_date in after_timestamps.iter() {
408 let step_progress_name = format!("rejected>{after_date}");
409 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
410 step_progress.set_substatus("querying");
411
412 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
413 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
414 let min_minor_str_ref = min_minor_str.as_deref();
415 let min_patch_str_ref = min_patch_str.as_deref();
416
417 let statement = indoc! {r#"
418 SELECT
419 req.event_properties:request_id::string AS request_id,
420 req.device_id::string AS device_id,
421 req.time::string AS time,
422 req.event_properties:input AS input,
423 req.event_properties:prompt::string AS prompt,
424 req.event_properties:output::string AS output,
425 rej.event_properties:was_shown::boolean AS was_shown,
426 rej.event_properties:reason::string AS reason,
427 req.event_properties:zed_version::string AS zed_version
428 FROM events req
429 INNER JOIN events rej
430 ON req.event_properties:request_id = rej.event_properties:request_id
431 WHERE req.event_type = ?
432 AND rej.event_type = ?
433 AND req.event_properties:version = 'V3'
434 AND rej.event_properties:was_shown = true
435 AND req.event_properties:input:can_collect_data = true
436 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
437 AND (? IS NULL OR (
438 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
439 OR (
440 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
441 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
442 )
443 ))
444 ORDER BY req.time ASC
445 LIMIT ?
446 OFFSET ?
447 "#};
448
449 let bindings = json!({
450 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
451 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
452 "3": { "type": "TEXT", "value": after_date },
453 "4": { "type": "FIXED", "value": min_minor_str_ref },
454 "5": { "type": "FIXED", "value": min_minor_str_ref },
455 "6": { "type": "FIXED", "value": min_minor_str_ref },
456 "7": { "type": "FIXED", "value": min_patch_str_ref },
457 "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
458 "9": { "type": "FIXED", "value": offset.to_string() }
459 });
460
461 let examples = fetch_examples_with_query(
462 http_client.clone(),
463 &step_progress,
464 background_executor.clone(),
465 statement,
466 bindings,
467 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
468 &[
469 "request_id",
470 "device_id",
471 "time",
472 "input",
473 "prompt",
474 "output",
475 "was_shown",
476 "reason",
477 "zed_version",
478 ],
479 rejected_examples_from_response,
480 )
481 .await?;
482
483 all_examples.extend(examples);
484 }
485
486 Ok(all_examples)
487}
488
489pub async fn fetch_requested_examples_after(
490 http_client: Arc<dyn HttpClient>,
491 after_timestamps: &[String],
492 max_rows_per_timestamp: usize,
493 offset: usize,
494 background_executor: BackgroundExecutor,
495 min_capture_version: Option<MinCaptureVersion>,
496) -> Result<Vec<Example>> {
497 if after_timestamps.is_empty() {
498 return Ok(Vec::new());
499 }
500
501 let progress = Progress::global();
502
503 let mut all_examples = Vec::new();
504
505 for after_date in after_timestamps.iter() {
506 let step_progress_name = format!("requested>{after_date}");
507 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
508 step_progress.set_substatus("querying");
509
510 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
511 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
512 let min_minor_str_ref = min_minor_str.as_deref();
513 let min_patch_str_ref = min_patch_str.as_deref();
514
515 let statement = indoc! {r#"
516 SELECT
517 req.event_properties:request_id::string AS request_id,
518 req.device_id::string AS device_id,
519 req.time::string AS time,
520 req.event_properties:input AS input,
521 req.event_properties:zed_version::string AS zed_version
522 FROM events req
523 WHERE req.event_type = ?
524 AND req.event_properties:version = 'V3'
525 AND req.event_properties:input:can_collect_data = true
526 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
527 AND (? IS NULL OR (
528 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
529 OR (
530 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
531 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
532 )
533 ))
534 ORDER BY req.time ASC
535 LIMIT ?
536 OFFSET ?
537 "#};
538
539 let bindings = json!({
540 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
541 "2": { "type": "TEXT", "value": after_date },
542 "3": { "type": "FIXED", "value": min_minor_str_ref },
543 "4": { "type": "FIXED", "value": min_minor_str_ref },
544 "5": { "type": "FIXED", "value": min_minor_str_ref },
545 "6": { "type": "FIXED", "value": min_patch_str_ref },
546 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
547 "8": { "type": "FIXED", "value": offset.to_string() }
548 });
549
550 let examples = fetch_examples_with_query(
551 http_client.clone(),
552 &step_progress,
553 background_executor.clone(),
554 statement,
555 bindings,
556 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
557 &["request_id", "device_id", "time", "input", "zed_version"],
558 requested_examples_from_response,
559 )
560 .await?;
561
562 all_examples.extend(examples);
563 }
564
565 Ok(all_examples)
566}
567
568pub async fn fetch_settled_examples_after(
569 http_client: Arc<dyn HttpClient>,
570 after_timestamps: &[String],
571 max_rows_per_timestamp: usize,
572 offset: usize,
573 background_executor: BackgroundExecutor,
574 min_capture_version: Option<MinCaptureVersion>,
575) -> Result<Vec<Example>> {
576 if after_timestamps.is_empty() {
577 return Ok(Vec::new());
578 }
579
580 let progress = Progress::global();
581
582 let mut all_examples = Vec::new();
583
584 for after_date in after_timestamps.iter() {
585 let step_progress_name = format!("settled>{after_date}");
586 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
587 step_progress.set_substatus("querying");
588
589 let _ = min_capture_version;
590
591 let statement = indoc! {r#"
592 WITH requested AS (
593 SELECT
594 req.event_properties:request_id::string AS request_id,
595 req.device_id::string AS device_id,
596 req.time AS req_time,
597 req.time::string AS time,
598 req.event_properties:input AS input,
599 req.event_properties:format::string AS requested_format,
600 req.event_properties:output::string AS requested_output,
601 req.event_properties:zed_version::string AS zed_version
602 FROM events req
603 WHERE req.event_type = ?
604 AND req.event_properties:version = 'V3'
605 AND req.event_properties:input:can_collect_data = true
606 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
607 )
608 SELECT
609 req.request_id AS request_id,
610 req.device_id AS device_id,
611 req.time AS time,
612 req.input AS input,
613 req.requested_output AS requested_output,
614 settled.event_properties:settled_editable_region::string AS settled_editable_region,
615 req.requested_format AS requested_format,
616 req.zed_version AS zed_version
617 FROM requested req
618 INNER JOIN events settled
619 ON req.request_id = settled.event_properties:request_id::string
620 WHERE settled.event_type = ?
621 ORDER BY req.req_time ASC
622 LIMIT ?
623 OFFSET ?
624 "#};
625
626 let bindings = json!({
627 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
628 "2": { "type": "TEXT", "value": after_date },
629 "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
630 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
631 "5": { "type": "FIXED", "value": offset.to_string() }
632 });
633
634 let examples = fetch_examples_with_query(
635 http_client.clone(),
636 &step_progress,
637 background_executor.clone(),
638 statement,
639 bindings,
640 SETTLED_STATEMENT_TIMEOUT_SECONDS,
641 &[
642 "request_id",
643 "device_id",
644 "time",
645 "input",
646 "requested_output",
647 "settled_editable_region",
648 "requested_format",
649 "zed_version",
650 ],
651 settled_examples_from_response,
652 )
653 .await?;
654
655 all_examples.extend(examples);
656 }
657
658 Ok(all_examples)
659}
660
661pub async fn fetch_rated_examples_after(
662 http_client: Arc<dyn HttpClient>,
663 inputs: &[(String, Option<EditPredictionRating>)],
664 max_rows_per_timestamp: usize,
665 offset: usize,
666 background_executor: BackgroundExecutor,
667 _min_capture_version: Option<MinCaptureVersion>,
668) -> Result<Vec<Example>> {
669 if inputs.is_empty() {
670 return Ok(Vec::new());
671 }
672
673 let progress = Progress::global();
674
675 let mut all_examples = Vec::new();
676
677 for (after_date, rating_filter) in inputs.iter() {
678 let filter_label = match rating_filter {
679 None => "",
680 Some(EditPredictionRating::Positive) => ":positive",
681 Some(EditPredictionRating::Negative) => ":negative",
682 };
683 let step_progress_name = format!("rated{filter_label}>{after_date}");
684 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
685 step_progress.set_substatus("querying");
686
687 let rating_value = rating_filter.as_ref().map(|rating| match rating {
688 EditPredictionRating::Positive => "Positive",
689 EditPredictionRating::Negative => "Negative",
690 });
691
692 let statement = indoc! {r#"
693 SELECT
694 rated.event_properties:request_id::string AS request_id,
695 rated.event_properties:inputs AS inputs,
696 rated.event_properties:output::string AS output,
697 rated.event_properties:rating::string AS rating,
698 rated.event_properties:feedback::string AS feedback,
699 rated.device_id::string AS device_id,
700 rated.time::string AS time,
701 deploy.event_properties:experiment_name::string AS experiment_name,
702 deploy.event_properties:environment::string AS environment,
703 rated.event_properties:zed_version::string AS zed_version
704 FROM events rated
705 LEFT JOIN events req
706 ON rated.event_properties:request_id::string = req.event_properties:request_id::string
707 AND req.event_type = ?
708 LEFT JOIN events deploy
709 ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
710 AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
711 AND deploy.event_type = ?
712 WHERE rated.event_type = ?
713 AND (? IS NULL OR rated.event_properties:rating::string = ?)
714 AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
715 AND rated.event_properties:inputs IS NOT NULL
716 AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
717 AND rated.event_properties:output IS NOT NULL
718 AND rated.event_properties:inputs:can_collect_data = true
719 ORDER BY rated.time ASC
720 LIMIT ?
721 OFFSET ?
722 "#};
723
724 let bindings = json!({
725 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
726 "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
727 "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
728 "4": { "type": "TEXT", "value": rating_value },
729 "5": { "type": "TEXT", "value": rating_value },
730 "6": { "type": "TEXT", "value": after_date },
731 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
732 "8": { "type": "FIXED", "value": offset.to_string() }
733 });
734
735 let examples = fetch_examples_with_query(
736 http_client.clone(),
737 &step_progress,
738 background_executor.clone(),
739 statement,
740 bindings,
741 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
742 &[
743 "request_id",
744 "inputs",
745 "output",
746 "rating",
747 "feedback",
748 "device_id",
749 "time",
750 "experiment_name",
751 "environment",
752 "zed_version",
753 ],
754 rated_examples_from_response,
755 )
756 .await?;
757
758 all_examples.extend(examples);
759 }
760
761 Ok(all_examples)
762}
763
764fn rated_examples_from_response<'a>(
765 response: &'a SnowflakeStatementResponse,
766 column_indices: &'a std::collections::HashMap<String, usize>,
767) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
768 if let Some(code) = &response.code {
769 if code != SNOWFLAKE_SUCCESS_CODE {
770 anyhow::bail!(
771 "snowflake sql api returned error code={code} message={}",
772 response.message.as_deref().unwrap_or("<no message>")
773 );
774 }
775 }
776
777 let iter = response
778 .data
779 .iter()
780 .enumerate()
781 .filter_map(move |(row_index, data_row)| {
782 let get_string = |name: &str| -> Option<String> {
783 let index = column_indices.get(name).copied()?;
784 match data_row.get(index)? {
785 JsonValue::String(s) => Some(s.clone()),
786 JsonValue::Null => None,
787 other => Some(other.to_string()),
788 }
789 };
790
791 let get_json = |name: &str| -> Option<JsonValue> {
792 let index = column_indices.get(name).copied()?;
793 let value = data_row.get(index)?;
794 if value.is_null() {
795 return None;
796 }
797 match value {
798 JsonValue::String(s) => serde_json::from_str(s).ok(),
799 other => Some(other.clone()),
800 }
801 };
802
803 let request_id = get_string("request_id");
804 let inputs_json = get_json("inputs");
805 let inputs: Option<ZetaPromptInput> = match &inputs_json {
806 Some(v) => match serde_json::from_value(v.clone()) {
807 Ok(parsed) => Some(parsed),
808 Err(e) => {
809 log::warn!(
810 "skipping row {row_index}: failed to parse inputs - {e}",
811 );
812 return None;
813 }
814 },
815 None => None,
816 };
817 let output = get_string("output");
818 let rating = get_string("rating");
819 let feedback = get_string("feedback").unwrap_or_default();
820 let device_id = get_string("device_id");
821 let time = get_string("time");
822 let experiment_name = get_string("experiment_name");
823 let environment = get_string("environment");
824 let zed_version = get_string("zed_version");
825
826 match (inputs, output.clone(), rating.clone(), time.clone()) {
827 (Some(inputs), Some(output), Some(rating), Some(time)) => {
828 Some(build_rated_example(
829 request_id,
830 device_id.unwrap_or_default(),
831 time,
832 inputs,
833 output,
834 rating,
835 feedback,
836 experiment_name,
837 environment,
838 zed_version,
839 ))
840 }
841 _ => {
842 log::warn!(
843 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} time={:?}",
844 inputs_json.is_some(),
845 output.is_some(),
846 rating.is_some(),
847 time.is_some(),
848 );
849 None
850 }
851 }
852 });
853
854 Ok(Box::new(iter))
855}
856
857fn build_rated_example(
858 request_id: Option<String>,
859 device_id: String,
860 time: String,
861 input: ZetaPromptInput,
862 output: String,
863 rating: String,
864 feedback: String,
865 experiment_name: Option<String>,
866 environment: Option<String>,
867 zed_version: Option<String>,
868) -> Example {
869 let parsed_rating = if rating == "Positive" {
870 EditPredictionRating::Positive
871 } else {
872 EditPredictionRating::Negative
873 };
874 let is_positive = parsed_rating == EditPredictionRating::Positive;
875 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
876
877 let mut tags = Vec::with_capacity(3);
878 tags.push(if is_positive {
879 "rated:positive".to_string()
880 } else {
881 "rated:negative".to_string()
882 });
883 if let Some(experiment) = experiment_name {
884 tags.push(format!("experiment:{experiment}"));
885 }
886 if let Some(env) = environment {
887 tags.push(format!("environment:{env}"));
888 }
889
890 let mut example =
891 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
892
893 example.spec.rating = Some(parsed_rating);
894
895 if !feedback.is_empty() {
896 example
897 .spec
898 .human_feedback
899 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
900 }
901
902 if is_positive {
903 example.spec.expected_patches = vec![output];
904 } else {
905 example.spec.rejected_patch = Some(output);
906 }
907
908 example
909}
910
911fn requested_examples_from_response<'a>(
912 response: &'a SnowflakeStatementResponse,
913 column_indices: &'a std::collections::HashMap<String, usize>,
914) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
915 if let Some(code) = &response.code {
916 if code != SNOWFLAKE_SUCCESS_CODE {
917 anyhow::bail!(
918 "snowflake sql api returned error code={code} message={}",
919 response.message.as_deref().unwrap_or("<no message>")
920 );
921 }
922 }
923
924 let iter = response
925 .data
926 .iter()
927 .enumerate()
928 .filter_map(move |(row_index, data_row)| {
929 let get_string = |name: &str| -> Option<String> {
930 let index = column_indices.get(name).copied()?;
931 match data_row.get(index)? {
932 JsonValue::String(s) => Some(s.clone()),
933 JsonValue::Null => None,
934 other => Some(other.to_string()),
935 }
936 };
937
938 let get_json = |name: &str| -> Option<JsonValue> {
939 let index = column_indices.get(name).copied()?;
940 let value = data_row.get(index)?;
941 if value.is_null() {
942 return None;
943 }
944 match value {
945 JsonValue::String(s) => serde_json::from_str(s).ok(),
946 other => Some(other.clone()),
947 }
948 };
949
950 let request_id_str = get_string("request_id");
951 let device_id = get_string("device_id");
952 let time = get_string("time");
953 let input_json = get_json("input");
954 let input: Option<ZetaPromptInput> =
955 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
956 let zed_version = get_string("zed_version");
957
958 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
959 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
960 Some(build_example_from_snowflake(
961 request_id,
962 device_id,
963 time,
964 input,
965 vec!["requested".to_string()],
966 None,
967 zed_version,
968 ))
969 }
970 _ => {
971 log::warn!(
972 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
973 request_id_str.is_some(),
974 device_id.is_some(),
975 time.is_some(),
976 input_json.is_some(),
977 );
978 None
979 }
980 }
981 });
982
983 Ok(Box::new(iter))
984}
985
986fn settled_examples_from_response<'a>(
987 response: &'a SnowflakeStatementResponse,
988 column_indices: &'a std::collections::HashMap<String, usize>,
989) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
990 if let Some(code) = &response.code {
991 if code != SNOWFLAKE_SUCCESS_CODE {
992 anyhow::bail!(
993 "snowflake sql api returned error code={code} message={}",
994 response.message.as_deref().unwrap_or("<no message>")
995 );
996 }
997 }
998
999 let iter = response
1000 .data
1001 .iter()
1002 .enumerate()
1003 .filter_map(move |(row_index, data_row)| {
1004 let get_value = |name: &str| -> Option<JsonValue> {
1005 let index = column_indices.get(name).copied()?;
1006 let value = data_row.get(index)?;
1007 if value.is_null() {
1008 None
1009 } else {
1010 Some(value.clone())
1011 }
1012 };
1013
1014 let get_string = |name: &str| -> Option<String> {
1015 match get_value(name)? {
1016 JsonValue::String(s) => Some(s),
1017 other => Some(other.to_string()),
1018 }
1019 };
1020
1021 let parse_json_value = |_: &str, raw: Option<&JsonValue>| -> Option<JsonValue> {
1022 let value = raw?;
1023 match value {
1024 JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1025 other => Some(other.clone()),
1026 }
1027 };
1028
1029 let request_id_str = get_string("request_id");
1030 let device_id = get_string("device_id");
1031 let time = get_string("time");
1032 let input_raw = get_value("input");
1033 let input_json = parse_json_value("input", input_raw.as_ref());
1034 let input: Option<ZetaPromptInput> = input_json
1035 .as_ref()
1036 .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1037 let requested_output = get_string("requested_output");
1038 let settled_editable_region = get_string("settled_editable_region");
1039 let requested_format =
1040 get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1041 let zed_version = get_string("zed_version");
1042
1043 match (
1044 request_id_str.clone(),
1045 device_id.clone(),
1046 time.clone(),
1047 input.clone(),
1048 requested_output.clone(),
1049 settled_editable_region.clone(),
1050 requested_format,
1051 ) {
1052 (
1053 Some(request_id),
1054 Some(device_id),
1055 Some(time),
1056 Some(input),
1057 Some(requested_output),
1058 Some(settled_editable_region),
1059 Some(requested_format),
1060 ) => Some(build_settled_example(
1061 request_id,
1062 device_id,
1063 time,
1064 input,
1065 requested_output,
1066 settled_editable_region,
1067 requested_format,
1068 zed_version,
1069 )),
1070 _ => {
1071 let mut missing_fields = Vec::new();
1072
1073 if request_id_str.is_none() {
1074 missing_fields.push("request_id");
1075 }
1076 if device_id.is_none() {
1077 missing_fields.push("device_id");
1078 }
1079 if time.is_none() {
1080 missing_fields.push("time");
1081 }
1082 if input_raw.is_none() || input_json.is_none() || input.is_none() {
1083 missing_fields.push("input");
1084 }
1085 if requested_output.is_none() {
1086 missing_fields.push("requested_output");
1087 }
1088 if settled_editable_region.is_none() {
1089 missing_fields.push("settled_editable_region");
1090 }
1091 if requested_format.is_none() {
1092 missing_fields.push("requested_format");
1093 }
1094
1095 log::warn!(
1096 "skipping settled row {row_index}: [{}]",
1097 missing_fields.join(", "),
1098 );
1099 None
1100 }
1101 }
1102 });
1103
1104 Ok(Box::new(iter))
1105}
1106
1107fn build_settled_example(
1108 request_id: String,
1109 device_id: String,
1110 time: String,
1111 input: ZetaPromptInput,
1112 requested_output: String,
1113 settled_editable_region: String,
1114 requested_format: ZetaFormat,
1115 zed_version: Option<String>,
1116) -> Example {
1117 let requested_editable_range =
1118 excerpt_range_for_format(requested_format, &input.excerpt_ranges).0;
1119
1120 let base_cursor_excerpt = input.cursor_excerpt.to_string();
1121
1122 let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1123 && requested_editable_range.end <= base_cursor_excerpt.len();
1124 let mut example = build_example_from_snowflake(
1125 request_id.clone(),
1126 device_id,
1127 time,
1128 input,
1129 vec!["settled".to_string()],
1130 None,
1131 zed_version,
1132 );
1133
1134 if !requested_range_is_valid {
1135 log::warn!(
1136 "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1137 request_id,
1138 requested_editable_range,
1139 base_cursor_excerpt.len(),
1140 );
1141 return example;
1142 }
1143
1144 let settled_replacement = settled_editable_region.as_str();
1145 let rejected_patch = build_output_patch(
1146 &example.spec.cursor_path,
1147 &base_cursor_excerpt,
1148 &requested_editable_range,
1149 &requested_output,
1150 );
1151 let expected_patch = build_output_patch(
1152 &example.spec.cursor_path,
1153 &base_cursor_excerpt,
1154 &requested_editable_range,
1155 settled_replacement,
1156 );
1157
1158 example.spec.expected_patches = vec![expected_patch];
1159 example.spec.rejected_patch = Some(rejected_patch);
1160 example
1161}
1162
1163fn rejected_examples_from_response<'a>(
1164 response: &'a SnowflakeStatementResponse,
1165 column_indices: &'a std::collections::HashMap<String, usize>,
1166) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1167 if let Some(code) = &response.code {
1168 if code != SNOWFLAKE_SUCCESS_CODE {
1169 anyhow::bail!(
1170 "snowflake sql api returned error code={code} message={}",
1171 response.message.as_deref().unwrap_or("<no message>")
1172 );
1173 }
1174 }
1175
1176 let iter = response
1177 .data
1178 .iter()
1179 .enumerate()
1180 .filter_map(move |(row_index, data_row)| {
1181 let get_string = |name: &str| -> Option<String> {
1182 let index = column_indices.get(name).copied()?;
1183 match data_row.get(index)? {
1184 JsonValue::String(s) => Some(s.clone()),
1185 JsonValue::Null => None,
1186 other => Some(other.to_string()),
1187 }
1188 };
1189
1190 let get_json = |name: &str| -> Option<JsonValue> {
1191 let index = column_indices.get(name).copied()?;
1192 let value = data_row.get(index)?;
1193 if value.is_null() {
1194 return None;
1195 }
1196 match value {
1197 JsonValue::String(s) => serde_json::from_str(s).ok(),
1198 other => Some(other.clone()),
1199 }
1200 };
1201
1202 let get_bool = |name: &str| -> Option<bool> {
1203 let index = column_indices.get(name).copied()?;
1204 match data_row.get(index)? {
1205 JsonValue::Bool(b) => Some(*b),
1206 JsonValue::String(s) => s.parse().ok(),
1207 _ => None,
1208 }
1209 };
1210
1211 let request_id_str = get_string("request_id");
1212 let device_id = get_string("device_id");
1213 let time = get_string("time");
1214 let input_json = get_json("input");
1215 let input: Option<ZetaPromptInput> =
1216 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1217 let output = get_string("output");
1218 let was_shown = get_bool("was_shown");
1219 let reason = get_string("reason");
1220 let zed_version = get_string("zed_version");
1221
1222 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1223 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1224 Some(build_rejected_example(
1225 request_id,
1226 device_id,
1227 time,
1228 input,
1229 output,
1230 was_shown,
1231 reason,
1232 zed_version,
1233 ))
1234 }
1235 _ => {
1236 log::warn!(
1237 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1238 request_id_str.is_some(),
1239 device_id.is_some(),
1240 time.is_some(),
1241 input_json.is_some(),
1242 output.is_some(),
1243 was_shown.is_some(),
1244 reason.is_some()
1245 );
1246 None
1247 }
1248 }
1249 });
1250
1251 Ok(Box::new(iter))
1252}
1253
1254fn build_rejected_example(
1255 request_id: String,
1256 device_id: String,
1257 time: String,
1258 input: ZetaPromptInput,
1259 output: String,
1260 was_shown: bool,
1261 reason: String,
1262 zed_version: Option<String>,
1263) -> Example {
1264 let rejected_patch = build_output_patch(
1265 &input.cursor_path,
1266 input.cursor_excerpt.as_ref(),
1267 &input.excerpt_ranges.editable_350,
1268 &output,
1269 );
1270 let mut example = build_example_from_snowflake(
1271 request_id,
1272 device_id,
1273 time,
1274 input,
1275 vec![format!("rejection:{}", reason.to_lowercase())],
1276 Some(RejectionInfo { reason, was_shown }),
1277 zed_version,
1278 );
1279 example.spec.rejected_patch = Some(rejected_patch);
1280 example
1281}
1282
/// Rejection metadata carried from a rejected-prediction telemetry row into the
/// example's telemetry record.
#[derive(Debug)]
struct RejectionInfo {
    /// The rejection reason exactly as reported in the telemetry row.
    reason: String,
    /// Whether the rejected prediction had been shown to the user.
    was_shown: bool,
}
1287
1288fn build_example_from_snowflake(
1289 request_id: String,
1290 device_id: String,
1291 time: String,
1292 input: ZetaPromptInput,
1293 tags: Vec<String>,
1294 rejection: Option<RejectionInfo>,
1295 zed_version: Option<String>,
1296) -> Example {
1297 let cursor_excerpt = input.cursor_excerpt.as_ref();
1298 let cursor_offset = input.cursor_offset_in_excerpt;
1299
1300 let mut edit_history = String::new();
1301 for event in &input.events {
1302 zeta_prompt::write_event(&mut edit_history, event);
1303 edit_history.push('\n');
1304 }
1305
1306 let (rejection_reason, was_shown) = match &rejection {
1307 Some(r) => (r.reason.clone(), r.was_shown),
1308 None => (String::new(), false),
1309 };
1310
1311 let spec = ExampleSpec {
1312 name: request_id.clone(),
1313 repository_url: String::new(),
1314 revision: String::new(),
1315 tags,
1316 reasoning: None,
1317 uncommitted_diff: String::new(),
1318 cursor_path: input.cursor_path.clone(),
1319 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1320 edit_history,
1321 expected_patches: Vec::new(),
1322 rejected_patch: None,
1323 telemetry: Some(TelemetrySource {
1324 request_id,
1325 device_id,
1326 time,
1327 rejection_reason,
1328 was_shown,
1329 }),
1330 human_feedback: Vec::new(),
1331 rating: None,
1332 };
1333
1334 Example {
1335 spec,
1336 zed_version,
1337 prompt_inputs: Some(input),
1338 prompt: None,
1339 predictions: Vec::new(),
1340 score: Vec::new(),
1341 qa: Vec::new(),
1342 state: None,
1343 }
1344}
1345
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// The offset comes from telemetry and is not trusted. It is clamped to the
/// excerpt length and then moved back to the nearest preceding UTF-8 character
/// boundary, so a malformed offset yields a slightly shifted marker instead of a
/// panic.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split) {
        split -= 1;
    }
    let (before, after) = excerpt.split_at(split);
    format!("{before}[CURSOR_POSITION]{after}")
}
1351
1352fn build_output_patch(
1353 cursor_path: &std::path::Path,
1354 cursor_excerpt: &str,
1355 editable_range: &std::ops::Range<usize>,
1356 model_output: &str,
1357) -> String {
1358 let old_text = &cursor_excerpt[editable_range.clone()];
1359
1360 let editable_start_row = cursor_excerpt[..editable_range.start]
1361 .chars()
1362 .filter(|&c| c == '\n')
1363 .count() as u32;
1364
1365 let diff_body = language::unified_diff_with_offsets(
1366 old_text,
1367 model_output,
1368 editable_start_row,
1369 editable_start_row,
1370 );
1371
1372 let mut patch = String::new();
1373 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1374 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1375 patch.push_str(&diff_body);
1376 patch
1377}
1378
1379pub(crate) fn get_column_indices(
1380 meta: &Option<SnowflakeResultSetMetaData>,
1381 names: &[&str],
1382) -> std::collections::HashMap<String, usize> {
1383 let mut indices = std::collections::HashMap::new();
1384 if let Some(meta) = meta {
1385 for (index, col) in meta.row_type.iter().enumerate() {
1386 for &name in names {
1387 if col.name.eq_ignore_ascii_case(name) {
1388 indices.insert(name.to_string(), index);
1389 }
1390 }
1391 }
1392 }
1393 indices
1394}