1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::fmt::Write as _;
9use std::io::Read;
10use std::sync::Arc;
11use std::time::Duration;
12use telemetry_events::EditPredictionRating;
13
14use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
15
16use crate::example::Example;
17use crate::progress::{InfoStyle, Progress, Step};
/// Telemetry `event_type` of deployment events. Rated predictions are joined against
/// these (by Baseten model id / model version id) to recover the experiment name and
/// environment of the model that served the request.
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
19use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
20
/// Snowflake SQL API response `code` treated as success by the row parsers.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Snowflake SQL API response `code` indicating the statement is still executing; the
/// result is then polled through the returned statement handle.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// Telemetry `event_type` used to select prediction request events.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// Telemetry `event_type` used to select prediction rejection events.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// Telemetry `event_type` used to select user rating events.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
/// Telemetry `event_type` used to select events carrying the settled editable region.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
27
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
/// where `zed_version >= 0.224.1`.
///
/// Only the second and third dot-separated components of `zed_version` are compared
/// by the queries; the major component is not checked.
#[derive(Clone, Copy, Debug)]
pub struct MinCaptureVersion {
    /// Minimum minor component (`0.<minor>.x`); any greater minor version passes.
    pub minor: u32,
    /// Minimum patch component, only consulted when the minor component is equal.
    /// Any `+build` suffix on the patch is ignored by the comparison.
    pub patch: u32,
}
36
/// Value sent as the statement request's `timeout` field for most queries.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Longer `timeout` used for the settled-examples query.
const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Delay before each poll of a statement that is still executing.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Number of polls attempted before giving up on a still-running statement.
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
41
/// Extracts the timestamp from a `captured-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `captured-after:` prefix; the
/// remainder is returned verbatim (it may be empty).
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input.strip_prefix(PREFIX)
}
46
/// Extracts the timestamp from a `rejected-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `rejected-after:` prefix; the
/// remainder is returned verbatim (it may be empty).
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    input.strip_prefix(PREFIX)
}
51
/// Extracts the timestamp from a `requested-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `requested-after:` prefix; the
/// remainder is returned verbatim (it may be empty).
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    input.strip_prefix(PREFIX)
}
56
/// Extracts the timestamp from a `settled-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `settled-after:` prefix; the
/// remainder is returned verbatim (it may be empty).
pub fn parse_settled_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "settled-after:";
    input.strip_prefix(PREFIX)
}
61
62/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
63/// or `rated-negative-after:{timestamp}`.
64/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
65pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
66 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
67 Some((timestamp, Some(EditPredictionRating::Positive)))
68 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
69 Some((timestamp, Some(EditPredictionRating::Negative)))
70 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
71 Some((timestamp, None))
72 } else {
73 None
74 }
75}
76
/// JSON body returned by the Snowflake SQL API statement endpoints.
///
/// Used both for the initial `POST` and for `GET` partition/poll requests. Every field
/// has `#[serde(default)]`, so bodies that omit some of them still deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of column values addressed by column index.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Column and partition metadata (`resultSetMetaData`), when present.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code, compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Human-readable status message, included in error reports.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll a running statement and fetch further partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
91
/// `resultSetMetaData` section of a statement response.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeResultSetMetaData {
    /// Column descriptions (`rowType`). The explicit rename matches what `rename_all`
    /// would produce anyway.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Reported row count; used for progress display and result preallocation.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
102
/// Placeholder for one `partitionInfo` entry. No fields are read; the entries exist
/// only so the number of partitions can be counted.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
106
/// One `rowType` column description; only the column name is deserialized.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name as reported by Snowflake (empty if absent).
    #[serde(default)]
    name: String,
}
112
113async fn run_sql_with_polling(
114 http_client: Arc<dyn HttpClient>,
115 base_url: &str,
116 token: &str,
117 request: &serde_json::Value,
118 step_progress: &crate::progress::StepProgress,
119 background_executor: BackgroundExecutor,
120) -> Result<SnowflakeStatementResponse> {
121 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
122
123 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
124 let statement_handle = response
125 .statement_handle
126 .as_ref()
127 .context("async query response missing statementHandle")?
128 .clone();
129
130 for attempt in 1..=MAX_POLL_ATTEMPTS {
131 step_progress.set_substatus(format!("polling ({attempt})"));
132
133 background_executor.timer(POLL_INTERVAL).await;
134
135 response =
136 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
137
138 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
139 break;
140 }
141 }
142
143 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
144 anyhow::bail!(
145 "query still running after {} poll attempts ({} seconds)",
146 MAX_POLL_ATTEMPTS,
147 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
148 );
149 }
150 }
151
152 Ok(response)
153}
154
/// Connection settings for the Snowflake SQL API, populated from the
/// `EP_SNOWFLAKE_*` environment variables.
struct SnowflakeConfig {
    /// Programmatic access token sent as a `Bearer` token.
    token: String,
    /// Account base URL, e.g. `https://<account>.snowflakecomputing.com`.
    base_url: String,
    /// Optional role sent with each statement (serialized as `null` when unset).
    role: Option<String>,
}
160
161async fn fetch_examples_with_query(
162 http_client: Arc<dyn HttpClient>,
163 step_progress: &crate::progress::StepProgress,
164 background_executor: BackgroundExecutor,
165 statement: &str,
166 bindings: JsonValue,
167 timeout_seconds: u64,
168 required_columns: &[&str],
169 parse_response: for<'a> fn(
170 &'a SnowflakeStatementResponse,
171 &'a std::collections::HashMap<String, usize>,
172 ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
173) -> Result<Vec<Example>> {
174 let snowflake = SnowflakeConfig {
175 token: std::env::var("EP_SNOWFLAKE_API_KEY")
176 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
177 base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context(
178 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
179 )?,
180 role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
181 };
182 let request = json!({
183 "statement": statement,
184 "timeout": timeout_seconds,
185 "database": "EVENTS",
186 "schema": "PUBLIC",
187 "warehouse": "DBT",
188 "role": snowflake.role.as_deref(),
189 "bindings": bindings
190 });
191
192 let response = run_sql_with_polling(
193 http_client.clone(),
194 &snowflake.base_url,
195 &snowflake.token,
196 &request,
197 step_progress,
198 background_executor,
199 )
200 .await?;
201
202 let total_rows = response
203 .result_set_meta_data
204 .as_ref()
205 .and_then(|meta| meta.num_rows)
206 .unwrap_or(response.data.len() as i64);
207 let partition_count = response
208 .result_set_meta_data
209 .as_ref()
210 .map(|meta| meta.partition_info.len())
211 .unwrap_or(1)
212 .max(1);
213
214 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
215 step_progress.set_substatus("parsing");
216
217 let column_indices = get_column_indices(&response.result_set_meta_data, required_columns);
218
219 let mut parsed_examples = Vec::with_capacity(total_rows as usize);
220 parsed_examples.extend(parse_response(&response, &column_indices)?);
221
222 if partition_count > 1 {
223 let statement_handle = response
224 .statement_handle
225 .as_ref()
226 .context("response has multiple partitions but no statementHandle")?;
227
228 for partition in 1..partition_count {
229 step_progress.set_substatus(format!(
230 "fetching partition {}/{}",
231 partition + 1,
232 partition_count
233 ));
234
235 let partition_response = fetch_partition(
236 http_client.clone(),
237 &snowflake.base_url,
238 &snowflake.token,
239 statement_handle,
240 partition,
241 )
242 .await?;
243
244 parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
245 }
246 }
247
248 step_progress.set_substatus("done");
249 Ok(parsed_examples)
250}
251
252pub(crate) async fn fetch_partition(
253 http_client: Arc<dyn HttpClient>,
254 base_url: &str,
255 token: &str,
256 statement_handle: &str,
257 partition: usize,
258) -> Result<SnowflakeStatementResponse> {
259 let url = format!(
260 "{}/api/v2/statements/{}?partition={}",
261 base_url.trim_end_matches('/'),
262 statement_handle,
263 partition
264 );
265
266 let http_request = Request::builder()
267 .method(Method::GET)
268 .uri(url.as_str())
269 .header("Authorization", format!("Bearer {token}"))
270 .header(
271 "X-Snowflake-Authorization-Token-Type",
272 "PROGRAMMATIC_ACCESS_TOKEN",
273 )
274 .header("Accept", "application/json")
275 .header("Accept-Encoding", "gzip")
276 .header("User-Agent", "edit_prediction_cli")
277 .body(AsyncBody::empty())?;
278
279 let response = http_client
280 .send(http_request)
281 .await
282 .context("failed to send partition request to Snowflake SQL API")?;
283
284 let status = response.status();
285 let content_encoding = response
286 .headers()
287 .get("content-encoding")
288 .and_then(|v| v.to_str().ok())
289 .map(|s| s.to_lowercase());
290
291 let body_bytes = {
292 use futures::AsyncReadExt as _;
293
294 let mut body = response.into_body();
295 let mut bytes = Vec::new();
296 body.read_to_end(&mut bytes)
297 .await
298 .context("failed to read Snowflake SQL API partition response body")?;
299 bytes
300 };
301
302 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
303 let mut decoder = GzDecoder::new(&body_bytes[..]);
304 let mut decompressed = Vec::new();
305 decoder
306 .read_to_end(&mut decompressed)
307 .context("failed to decompress gzip response")?;
308 decompressed
309 } else {
310 body_bytes
311 };
312
313 if !status.is_success() && status.as_u16() != 202 {
314 let body_text = String::from_utf8_lossy(&body_bytes);
315 anyhow::bail!(
316 "snowflake sql api partition request http {}: {}",
317 status.as_u16(),
318 body_text
319 );
320 }
321
322 if body_bytes.is_empty() {
323 anyhow::bail!(
324 "snowflake sql api partition {} returned empty response body (http {})",
325 partition,
326 status.as_u16()
327 );
328 }
329
330 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
331 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
332 format!(
333 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
334 partition,
335 status.as_u16(),
336 body_preview
337 )
338 })
339}
340
341pub(crate) async fn run_sql(
342 http_client: Arc<dyn HttpClient>,
343 base_url: &str,
344 token: &str,
345 request: &serde_json::Value,
346) -> Result<SnowflakeStatementResponse> {
347 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
348
349 let request_body =
350 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
351
352 let http_request = Request::builder()
353 .method(Method::POST)
354 .uri(url.as_str())
355 .header("Authorization", format!("Bearer {token}"))
356 .header(
357 "X-Snowflake-Authorization-Token-Type",
358 "PROGRAMMATIC_ACCESS_TOKEN",
359 )
360 .header("Content-Type", "application/json")
361 .header("Accept", "application/json")
362 .header("User-Agent", "edit_prediction_cli")
363 .body(AsyncBody::from(request_body.clone()))?;
364
365 let response = http_client
366 .send(http_request)
367 .await
368 .context("failed to send request to Snowflake SQL API")?;
369
370 let status = response.status();
371 let body_bytes = {
372 use futures::AsyncReadExt as _;
373
374 let mut body = response.into_body();
375 let mut bytes = Vec::new();
376 body.read_to_end(&mut bytes)
377 .await
378 .context("failed to read Snowflake SQL API response body")?;
379 bytes
380 };
381
382 if !status.is_success() && status.as_u16() != 202 {
383 let body_text = String::from_utf8_lossy(&body_bytes);
384 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
385 }
386
387 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
388 .context("failed to parse Snowflake SQL API response JSON")
389}
390
391pub async fn fetch_rejected_examples_after(
392 http_client: Arc<dyn HttpClient>,
393 after_timestamps: &[String],
394 max_rows_per_timestamp: usize,
395 offset: usize,
396 background_executor: BackgroundExecutor,
397 min_capture_version: Option<MinCaptureVersion>,
398) -> Result<Vec<Example>> {
399 if after_timestamps.is_empty() {
400 return Ok(Vec::new());
401 }
402
403 let progress = Progress::global();
404
405 let mut all_examples = Vec::new();
406
407 for after_date in after_timestamps.iter() {
408 let step_progress_name = format!("rejected>{after_date}");
409 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
410 step_progress.set_substatus("querying");
411
412 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
413 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
414 let min_minor_str_ref = min_minor_str.as_deref();
415 let min_patch_str_ref = min_patch_str.as_deref();
416
417 let statement = indoc! {r#"
418 SELECT
419 req.event_properties:request_id::string AS request_id,
420 req.device_id::string AS device_id,
421 req.time::string AS time,
422 req.event_properties:input AS input,
423 req.event_properties:prompt::string AS prompt,
424 req.event_properties:output::string AS output,
425 rej.event_properties:was_shown::boolean AS was_shown,
426 rej.event_properties:reason::string AS reason,
427 req.event_properties:zed_version::string AS zed_version
428 FROM events req
429 INNER JOIN events rej
430 ON req.event_properties:request_id = rej.event_properties:request_id
431 WHERE req.event_type = ?
432 AND rej.event_type = ?
433 AND req.event_properties:version = 'V3'
434 AND rej.event_properties:was_shown = true
435 AND req.event_properties:input:can_collect_data = true
436 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
437 AND (? IS NULL OR (
438 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
439 OR (
440 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
441 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
442 )
443 ))
444 ORDER BY req.time ASC
445 LIMIT ?
446 OFFSET ?
447 "#};
448
449 let bindings = json!({
450 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
451 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
452 "3": { "type": "TEXT", "value": after_date },
453 "4": { "type": "FIXED", "value": min_minor_str_ref },
454 "5": { "type": "FIXED", "value": min_minor_str_ref },
455 "6": { "type": "FIXED", "value": min_minor_str_ref },
456 "7": { "type": "FIXED", "value": min_patch_str_ref },
457 "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
458 "9": { "type": "FIXED", "value": offset.to_string() }
459 });
460
461 let examples = fetch_examples_with_query(
462 http_client.clone(),
463 &step_progress,
464 background_executor.clone(),
465 statement,
466 bindings,
467 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
468 &[
469 "request_id",
470 "device_id",
471 "time",
472 "input",
473 "prompt",
474 "output",
475 "was_shown",
476 "reason",
477 "zed_version",
478 ],
479 rejected_examples_from_response,
480 )
481 .await?;
482
483 all_examples.extend(examples);
484 }
485
486 Ok(all_examples)
487}
488
489pub async fn fetch_requested_examples_after(
490 http_client: Arc<dyn HttpClient>,
491 after_timestamps: &[String],
492 max_rows_per_timestamp: usize,
493 offset: usize,
494 background_executor: BackgroundExecutor,
495 min_capture_version: Option<MinCaptureVersion>,
496) -> Result<Vec<Example>> {
497 if after_timestamps.is_empty() {
498 return Ok(Vec::new());
499 }
500
501 let progress = Progress::global();
502
503 let mut all_examples = Vec::new();
504
505 for after_date in after_timestamps.iter() {
506 let step_progress_name = format!("requested>{after_date}");
507 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
508 step_progress.set_substatus("querying");
509
510 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
511 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
512 let min_minor_str_ref = min_minor_str.as_deref();
513 let min_patch_str_ref = min_patch_str.as_deref();
514
515 let statement = indoc! {r#"
516 SELECT
517 req.event_properties:request_id::string AS request_id,
518 req.device_id::string AS device_id,
519 req.time::string AS time,
520 req.event_properties:input AS input,
521 req.event_properties:zed_version::string AS zed_version
522 FROM events req
523 WHERE req.event_type = ?
524 AND req.event_properties:version = 'V3'
525 AND req.event_properties:input:can_collect_data = true
526 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
527 AND (? IS NULL OR (
528 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
529 OR (
530 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
531 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
532 )
533 ))
534 ORDER BY req.time ASC
535 LIMIT ?
536 OFFSET ?
537 "#};
538
539 let bindings = json!({
540 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
541 "2": { "type": "TEXT", "value": after_date },
542 "3": { "type": "FIXED", "value": min_minor_str_ref },
543 "4": { "type": "FIXED", "value": min_minor_str_ref },
544 "5": { "type": "FIXED", "value": min_minor_str_ref },
545 "6": { "type": "FIXED", "value": min_patch_str_ref },
546 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
547 "8": { "type": "FIXED", "value": offset.to_string() }
548 });
549
550 let examples = fetch_examples_with_query(
551 http_client.clone(),
552 &step_progress,
553 background_executor.clone(),
554 statement,
555 bindings,
556 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
557 &["request_id", "device_id", "time", "input", "zed_version"],
558 requested_examples_from_response,
559 )
560 .await?;
561
562 all_examples.extend(examples);
563 }
564
565 Ok(all_examples)
566}
567
568pub async fn fetch_settled_examples_after(
569 http_client: Arc<dyn HttpClient>,
570 after_timestamps: &[String],
571 max_rows_per_timestamp: usize,
572 offset: usize,
573 background_executor: BackgroundExecutor,
574 min_capture_version: Option<MinCaptureVersion>,
575) -> Result<Vec<Example>> {
576 if after_timestamps.is_empty() {
577 return Ok(Vec::new());
578 }
579
580 let progress = Progress::global();
581
582 let mut all_examples = Vec::new();
583
584 for after_date in after_timestamps.iter() {
585 let step_progress_name = format!("settled>{after_date}");
586 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
587 step_progress.set_substatus("querying");
588
589 let _ = min_capture_version;
590
591 let statement = indoc! {r#"
592 WITH requested AS (
593 SELECT
594 req.event_properties:request_id::string AS request_id,
595 req.device_id::string AS device_id,
596 req.time AS req_time,
597 req.time::string AS time,
598 req.event_properties:input AS input,
599 req.event_properties:format::string AS requested_format,
600 req.event_properties:output::string AS requested_output,
601 req.event_properties:zed_version::string AS zed_version
602 FROM events req
603 WHERE req.event_type = ?
604 AND req.event_properties:version = 'V3'
605 AND req.event_properties:input:can_collect_data = true
606 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
607 )
608 SELECT
609 req.request_id AS request_id,
610 req.device_id AS device_id,
611 req.time AS time,
612 req.input AS input,
613 req.requested_output AS requested_output,
614 settled.event_properties:settled_editable_region::string AS settled_editable_region,
615 req.requested_format AS requested_format,
616 req.zed_version AS zed_version
617 FROM requested req
618 INNER JOIN events settled
619 ON req.request_id = settled.event_properties:request_id::string
620 WHERE settled.event_type = ?
621 ORDER BY req.req_time ASC
622 LIMIT ?
623 OFFSET ?
624 "#};
625
626 let bindings = json!({
627 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
628 "2": { "type": "TEXT", "value": after_date },
629 "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
630 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
631 "5": { "type": "FIXED", "value": offset.to_string() }
632 });
633
634 let examples = fetch_examples_with_query(
635 http_client.clone(),
636 &step_progress,
637 background_executor.clone(),
638 statement,
639 bindings,
640 SETTLED_STATEMENT_TIMEOUT_SECONDS,
641 &[
642 "request_id",
643 "device_id",
644 "time",
645 "input",
646 "requested_output",
647 "settled_editable_region",
648 "requested_format",
649 "zed_version",
650 ],
651 settled_examples_from_response,
652 )
653 .await?;
654
655 all_examples.extend(examples);
656 }
657
658 Ok(all_examples)
659}
660
661pub async fn fetch_rated_examples_after(
662 http_client: Arc<dyn HttpClient>,
663 inputs: &[(String, Option<EditPredictionRating>)],
664 max_rows_per_timestamp: usize,
665 offset: usize,
666 background_executor: BackgroundExecutor,
667 _min_capture_version: Option<MinCaptureVersion>,
668) -> Result<Vec<Example>> {
669 if inputs.is_empty() {
670 return Ok(Vec::new());
671 }
672
673 let progress = Progress::global();
674
675 let mut all_examples = Vec::new();
676
677 for (after_date, rating_filter) in inputs.iter() {
678 let filter_label = match rating_filter {
679 None => "",
680 Some(EditPredictionRating::Positive) => ":positive",
681 Some(EditPredictionRating::Negative) => ":negative",
682 };
683 let step_progress_name = format!("rated{filter_label}>{after_date}");
684 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
685 step_progress.set_substatus("querying");
686
687 let rating_value = rating_filter.as_ref().map(|rating| match rating {
688 EditPredictionRating::Positive => "Positive",
689 EditPredictionRating::Negative => "Negative",
690 });
691
692 let statement = indoc! {r#"
693 SELECT
694 rated.event_properties:request_id::string AS request_id,
695 rated.event_properties:inputs AS inputs,
696 rated.event_properties:output::string AS output,
697 rated.event_properties:rating::string AS rating,
698 rated.event_properties:feedback::string AS feedback,
699 rated.device_id::string AS device_id,
700 rated.time::string AS time,
701 deploy.event_properties:experiment_name::string AS experiment_name,
702 deploy.event_properties:environment::string AS environment,
703 rated.event_properties:zed_version::string AS zed_version
704 FROM events rated
705 LEFT JOIN events req
706 ON rated.event_properties:request_id::string = req.event_properties:request_id::string
707 AND req.event_type = ?
708 LEFT JOIN events deploy
709 ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
710 AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
711 AND deploy.event_type = ?
712 WHERE rated.event_type = ?
713 AND (? IS NULL OR rated.event_properties:rating::string = ?)
714 AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
715 AND rated.event_properties:inputs IS NOT NULL
716 AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
717 AND rated.event_properties:output IS NOT NULL
718 AND rated.event_properties:can_collect_data = true
719 ORDER BY rated.time ASC
720 LIMIT ?
721 OFFSET ?
722 "#};
723
724 let bindings = json!({
725 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
726 "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
727 "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
728 "4": { "type": "TEXT", "value": rating_value },
729 "5": { "type": "TEXT", "value": rating_value },
730 "6": { "type": "TEXT", "value": after_date },
731 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
732 "8": { "type": "FIXED", "value": offset.to_string() }
733 });
734
735 let examples = fetch_examples_with_query(
736 http_client.clone(),
737 &step_progress,
738 background_executor.clone(),
739 statement,
740 bindings,
741 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
742 &[
743 "request_id",
744 "inputs",
745 "output",
746 "rating",
747 "feedback",
748 "device_id",
749 "time",
750 "experiment_name",
751 "environment",
752 "zed_version",
753 ],
754 rated_examples_from_response,
755 )
756 .await?;
757
758 all_examples.extend(examples);
759 }
760
761 Ok(all_examples)
762}
763
764fn rated_examples_from_response<'a>(
765 response: &'a SnowflakeStatementResponse,
766 column_indices: &'a std::collections::HashMap<String, usize>,
767) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
768 if let Some(code) = &response.code {
769 if code != SNOWFLAKE_SUCCESS_CODE {
770 anyhow::bail!(
771 "snowflake sql api returned error code={code} message={}",
772 response.message.as_deref().unwrap_or("<no message>")
773 );
774 }
775 }
776
777 let iter = response
778 .data
779 .iter()
780 .enumerate()
781 .filter_map(move |(row_index, data_row)| {
782 let get_string = |name: &str| -> Option<String> {
783 let index = column_indices.get(name).copied()?;
784 match data_row.get(index)? {
785 JsonValue::String(s) => Some(s.clone()),
786 JsonValue::Null => None,
787 other => Some(other.to_string()),
788 }
789 };
790
791 let get_json = |name: &str| -> Option<JsonValue> {
792 let index = column_indices.get(name).copied()?;
793 let value = data_row.get(index)?;
794 if value.is_null() {
795 return None;
796 }
797 match value {
798 JsonValue::String(s) => serde_json::from_str(s).ok(),
799 other => Some(other.clone()),
800 }
801 };
802
803 let request_id = get_string("request_id");
804 let inputs_json = get_json("inputs");
805 let inputs: Option<ZetaPromptInput> = match &inputs_json {
806 Some(v) => match serde_json::from_value(v.clone()) {
807 Ok(parsed) => Some(parsed),
808 Err(e) => {
809 log::warn!(
810 "skipping row {row_index}: failed to parse inputs - {e}",
811 );
812 return None;
813 }
814 },
815 None => None,
816 };
817 let output = get_string("output");
818 let rating = get_string("rating");
819 let feedback = get_string("feedback").unwrap_or_default();
820 let device_id = get_string("device_id");
821 let time = get_string("time");
822 let experiment_name = get_string("experiment_name");
823 let environment = get_string("environment");
824 let zed_version = get_string("zed_version");
825
826 match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
827 (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
828 Some(build_rated_example(
829 request_id,
830 device_id,
831 time,
832 inputs,
833 output,
834 rating,
835 feedback,
836 experiment_name,
837 environment,
838 zed_version,
839 ))
840 }
841 _ => {
842 log::warn!(
843 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
844 inputs_json.is_some(),
845 output.is_some(),
846 rating.is_some(),
847 device_id.is_some(),
848 time.is_some(),
849 );
850 None
851 }
852 }
853 });
854
855 Ok(Box::new(iter))
856}
857
858fn build_rated_example(
859 request_id: Option<String>,
860 device_id: String,
861 time: String,
862 input: ZetaPromptInput,
863 output: String,
864 rating: String,
865 feedback: String,
866 experiment_name: Option<String>,
867 environment: Option<String>,
868 zed_version: Option<String>,
869) -> Example {
870 let parsed_rating = if rating == "Positive" {
871 EditPredictionRating::Positive
872 } else {
873 EditPredictionRating::Negative
874 };
875 let is_positive = parsed_rating == EditPredictionRating::Positive;
876 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
877
878 let mut tags = Vec::with_capacity(3);
879 tags.push(if is_positive {
880 "rated:positive".to_string()
881 } else {
882 "rated:negative".to_string()
883 });
884 if let Some(experiment) = experiment_name {
885 tags.push(format!("experiment:{experiment}"));
886 }
887 if let Some(env) = environment {
888 tags.push(format!("environment:{env}"));
889 }
890
891 let mut example =
892 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
893
894 example.spec.rating = Some(parsed_rating);
895
896 if !feedback.is_empty() {
897 example
898 .spec
899 .human_feedback
900 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
901 }
902
903 if is_positive {
904 example.spec.expected_patches = vec![output];
905 } else {
906 example.spec.rejected_patch = Some(output);
907 }
908
909 example
910}
911
912fn requested_examples_from_response<'a>(
913 response: &'a SnowflakeStatementResponse,
914 column_indices: &'a std::collections::HashMap<String, usize>,
915) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
916 if let Some(code) = &response.code {
917 if code != SNOWFLAKE_SUCCESS_CODE {
918 anyhow::bail!(
919 "snowflake sql api returned error code={code} message={}",
920 response.message.as_deref().unwrap_or("<no message>")
921 );
922 }
923 }
924
925 let iter = response
926 .data
927 .iter()
928 .enumerate()
929 .filter_map(move |(row_index, data_row)| {
930 let get_string = |name: &str| -> Option<String> {
931 let index = column_indices.get(name).copied()?;
932 match data_row.get(index)? {
933 JsonValue::String(s) => Some(s.clone()),
934 JsonValue::Null => None,
935 other => Some(other.to_string()),
936 }
937 };
938
939 let get_json = |name: &str| -> Option<JsonValue> {
940 let index = column_indices.get(name).copied()?;
941 let value = data_row.get(index)?;
942 if value.is_null() {
943 return None;
944 }
945 match value {
946 JsonValue::String(s) => serde_json::from_str(s).ok(),
947 other => Some(other.clone()),
948 }
949 };
950
951 let request_id_str = get_string("request_id");
952 let device_id = get_string("device_id");
953 let time = get_string("time");
954 let input_json = get_json("input");
955 let input: Option<ZetaPromptInput> =
956 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
957 let zed_version = get_string("zed_version");
958
959 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
960 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
961 Some(build_example_from_snowflake(
962 request_id,
963 device_id,
964 time,
965 input,
966 vec!["requested".to_string()],
967 None,
968 zed_version,
969 ))
970 }
971 _ => {
972 log::warn!(
973 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
974 request_id_str.is_some(),
975 device_id.is_some(),
976 time.is_some(),
977 input_json.is_some(),
978 );
979 None
980 }
981 }
982 });
983
984 Ok(Box::new(iter))
985}
986
987fn settled_examples_from_response<'a>(
988 response: &'a SnowflakeStatementResponse,
989 column_indices: &'a std::collections::HashMap<String, usize>,
990) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
991 if let Some(code) = &response.code {
992 if code != SNOWFLAKE_SUCCESS_CODE {
993 anyhow::bail!(
994 "snowflake sql api returned error code={code} message={}",
995 response.message.as_deref().unwrap_or("<no message>")
996 );
997 }
998 }
999
1000 let iter = response
1001 .data
1002 .iter()
1003 .enumerate()
1004 .filter_map(move |(row_index, data_row)| {
1005 let get_value = |name: &str| -> Option<JsonValue> {
1006 let index = column_indices.get(name).copied()?;
1007 let value = data_row.get(index)?;
1008 if value.is_null() {
1009 None
1010 } else {
1011 Some(value.clone())
1012 }
1013 };
1014
1015 let get_string = |name: &str| -> Option<String> {
1016 match get_value(name)? {
1017 JsonValue::String(s) => Some(s),
1018 other => Some(other.to_string()),
1019 }
1020 };
1021
1022 let parse_json_value = |_: &str, raw: Option<&JsonValue>| -> Option<JsonValue> {
1023 let value = raw?;
1024 match value {
1025 JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1026 other => Some(other.clone()),
1027 }
1028 };
1029
1030 let request_id_str = get_string("request_id");
1031 let device_id = get_string("device_id");
1032 let time = get_string("time");
1033 let input_raw = get_value("input");
1034 let input_json = parse_json_value("input", input_raw.as_ref());
1035 let input: Option<ZetaPromptInput> = input_json
1036 .as_ref()
1037 .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1038 let requested_output = get_string("requested_output");
1039 let settled_editable_region = get_string("settled_editable_region");
1040 let requested_format =
1041 get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1042 let zed_version = get_string("zed_version");
1043
1044 match (
1045 request_id_str.clone(),
1046 device_id.clone(),
1047 time.clone(),
1048 input.clone(),
1049 requested_output.clone(),
1050 settled_editable_region.clone(),
1051 requested_format,
1052 ) {
1053 (
1054 Some(request_id),
1055 Some(device_id),
1056 Some(time),
1057 Some(input),
1058 Some(requested_output),
1059 Some(settled_editable_region),
1060 Some(requested_format),
1061 ) => Some(build_settled_example(
1062 request_id,
1063 device_id,
1064 time,
1065 input,
1066 requested_output,
1067 settled_editable_region,
1068 requested_format,
1069 zed_version,
1070 )),
1071 _ => {
1072 let mut missing_fields = Vec::new();
1073
1074 if request_id_str.is_none() {
1075 missing_fields.push("request_id");
1076 }
1077 if device_id.is_none() {
1078 missing_fields.push("device_id");
1079 }
1080 if time.is_none() {
1081 missing_fields.push("time");
1082 }
1083 if input_raw.is_none() || input_json.is_none() || input.is_none() {
1084 missing_fields.push("input");
1085 }
1086 if requested_output.is_none() {
1087 missing_fields.push("requested_output");
1088 }
1089 if settled_editable_region.is_none() {
1090 missing_fields.push("settled_editable_region");
1091 }
1092 if requested_format.is_none() {
1093 missing_fields.push("requested_format");
1094 }
1095
1096 log::warn!(
1097 "skipping settled row {row_index}: [{}]",
1098 missing_fields.join(", "),
1099 );
1100 None
1101 }
1102 }
1103 });
1104
1105 Ok(Box::new(iter))
1106}
1107
1108fn build_settled_example(
1109 request_id: String,
1110 device_id: String,
1111 time: String,
1112 input: ZetaPromptInput,
1113 requested_output: String,
1114 settled_editable_region: String,
1115 requested_format: ZetaFormat,
1116 zed_version: Option<String>,
1117) -> Example {
1118 let requested_editable_range = input
1119 .excerpt_ranges
1120 .as_ref()
1121 .map(|ranges| excerpt_range_for_format(requested_format, ranges).0)
1122 .unwrap_or_else(|| input.editable_range_in_excerpt.clone());
1123
1124 let base_cursor_excerpt = input.cursor_excerpt.to_string();
1125
1126 let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1127 && requested_editable_range.end <= base_cursor_excerpt.len();
1128 let mut example = build_example_from_snowflake(
1129 request_id.clone(),
1130 device_id,
1131 time,
1132 input,
1133 vec!["settled".to_string()],
1134 None,
1135 zed_version,
1136 );
1137
1138 if !requested_range_is_valid {
1139 log::warn!(
1140 "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1141 request_id,
1142 requested_editable_range,
1143 base_cursor_excerpt.len(),
1144 );
1145 return example;
1146 }
1147
1148 let settled_replacement = settled_editable_region.as_str();
1149 let rejected_patch = build_output_patch(
1150 &example.spec.cursor_path,
1151 &base_cursor_excerpt,
1152 &requested_editable_range,
1153 &requested_output,
1154 );
1155 let expected_patch = build_output_patch(
1156 &example.spec.cursor_path,
1157 &base_cursor_excerpt,
1158 &requested_editable_range,
1159 settled_replacement,
1160 );
1161
1162 example.spec.expected_patches = vec![expected_patch];
1163 example.spec.rejected_patch = Some(rejected_patch);
1164 example
1165}
1166
1167fn rejected_examples_from_response<'a>(
1168 response: &'a SnowflakeStatementResponse,
1169 column_indices: &'a std::collections::HashMap<String, usize>,
1170) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1171 if let Some(code) = &response.code {
1172 if code != SNOWFLAKE_SUCCESS_CODE {
1173 anyhow::bail!(
1174 "snowflake sql api returned error code={code} message={}",
1175 response.message.as_deref().unwrap_or("<no message>")
1176 );
1177 }
1178 }
1179
1180 let iter = response
1181 .data
1182 .iter()
1183 .enumerate()
1184 .filter_map(move |(row_index, data_row)| {
1185 let get_string = |name: &str| -> Option<String> {
1186 let index = column_indices.get(name).copied()?;
1187 match data_row.get(index)? {
1188 JsonValue::String(s) => Some(s.clone()),
1189 JsonValue::Null => None,
1190 other => Some(other.to_string()),
1191 }
1192 };
1193
1194 let get_json = |name: &str| -> Option<JsonValue> {
1195 let index = column_indices.get(name).copied()?;
1196 let value = data_row.get(index)?;
1197 if value.is_null() {
1198 return None;
1199 }
1200 match value {
1201 JsonValue::String(s) => serde_json::from_str(s).ok(),
1202 other => Some(other.clone()),
1203 }
1204 };
1205
1206 let get_bool = |name: &str| -> Option<bool> {
1207 let index = column_indices.get(name).copied()?;
1208 match data_row.get(index)? {
1209 JsonValue::Bool(b) => Some(*b),
1210 JsonValue::String(s) => s.parse().ok(),
1211 _ => None,
1212 }
1213 };
1214
1215 let request_id_str = get_string("request_id");
1216 let device_id = get_string("device_id");
1217 let time = get_string("time");
1218 let input_json = get_json("input");
1219 let input: Option<ZetaPromptInput> =
1220 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1221 let output = get_string("output");
1222 let was_shown = get_bool("was_shown");
1223 let reason = get_string("reason");
1224 let zed_version = get_string("zed_version");
1225
1226 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1227 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1228 Some(build_rejected_example(
1229 request_id,
1230 device_id,
1231 time,
1232 input,
1233 output,
1234 was_shown,
1235 reason,
1236 zed_version,
1237 ))
1238 }
1239 _ => {
1240 log::warn!(
1241 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1242 request_id_str.is_some(),
1243 device_id.is_some(),
1244 time.is_some(),
1245 input_json.is_some(),
1246 output.is_some(),
1247 was_shown.is_some(),
1248 reason.is_some()
1249 );
1250 None
1251 }
1252 }
1253 });
1254
1255 Ok(Box::new(iter))
1256}
1257
1258fn build_rejected_example(
1259 request_id: String,
1260 device_id: String,
1261 time: String,
1262 input: ZetaPromptInput,
1263 output: String,
1264 was_shown: bool,
1265 reason: String,
1266 zed_version: Option<String>,
1267) -> Example {
1268 let rejected_patch = build_output_patch(
1269 &input.cursor_path,
1270 input.cursor_excerpt.as_ref(),
1271 &input.editable_range_in_excerpt,
1272 &output,
1273 );
1274 let mut example = build_example_from_snowflake(
1275 request_id,
1276 device_id,
1277 time,
1278 input,
1279 vec![format!("rejection:{}", reason.to_lowercase())],
1280 Some(RejectionInfo { reason, was_shown }),
1281 zed_version,
1282 );
1283 example.spec.rejected_patch = Some(rejected_patch);
1284 example
1285}
1286
/// Rejection details attached to an example built from a rejected edit prediction.
#[derive(Clone, Debug)]
struct RejectionInfo {
    /// Rejection reason as read from the telemetry row.
    reason: String,
    /// The `was_shown` flag read from the telemetry row.
    was_shown: bool,
}
1291
1292fn build_example_from_snowflake(
1293 request_id: String,
1294 device_id: String,
1295 time: String,
1296 input: ZetaPromptInput,
1297 tags: Vec<String>,
1298 rejection: Option<RejectionInfo>,
1299 zed_version: Option<String>,
1300) -> Example {
1301 let cursor_excerpt = input.cursor_excerpt.as_ref();
1302 let cursor_offset = input.cursor_offset_in_excerpt;
1303
1304 let mut edit_history = String::new();
1305 for event in &input.events {
1306 zeta_prompt::write_event(&mut edit_history, event);
1307 edit_history.push('\n');
1308 }
1309
1310 let (rejection_reason, was_shown) = match &rejection {
1311 Some(r) => (r.reason.clone(), r.was_shown),
1312 None => (String::new(), false),
1313 };
1314
1315 let spec = ExampleSpec {
1316 name: request_id.clone(),
1317 repository_url: String::new(),
1318 revision: String::new(),
1319 tags,
1320 reasoning: None,
1321 uncommitted_diff: String::new(),
1322 cursor_path: input.cursor_path.clone(),
1323 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1324 edit_history,
1325 expected_patches: Vec::new(),
1326 rejected_patch: None,
1327 telemetry: Some(TelemetrySource {
1328 request_id,
1329 device_id,
1330 time,
1331 rejection_reason,
1332 was_shown,
1333 }),
1334 human_feedback: Vec::new(),
1335 rating: None,
1336 };
1337
1338 Example {
1339 spec,
1340 zed_version,
1341 prompt_inputs: Some(input),
1342 prompt: None,
1343 predictions: Vec::new(),
1344 score: Vec::new(),
1345 qa: Vec::new(),
1346 state: None,
1347 }
1348}
1349
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte `cursor_offset`.
///
/// An offset past the end of `excerpt` is clamped to its length. An offset that falls
/// inside a multi-byte UTF-8 character is moved back to the start of that character.
/// Slicing there would otherwise panic on non-ASCII excerpts.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut offset = cursor_offset.min(excerpt.len());
    // Terminates because index 0 is always a char boundary.
    while !excerpt.is_char_boundary(offset) {
        offset -= 1;
    }
    let (before, after) = excerpt.split_at(offset);
    format!("{before}[CURSOR_POSITION]{after}")
}
1355
1356fn build_output_patch(
1357 cursor_path: &std::path::Path,
1358 cursor_excerpt: &str,
1359 editable_range: &std::ops::Range<usize>,
1360 model_output: &str,
1361) -> String {
1362 let old_text = &cursor_excerpt[editable_range.clone()];
1363
1364 let editable_start_row = cursor_excerpt[..editable_range.start]
1365 .chars()
1366 .filter(|&c| c == '\n')
1367 .count() as u32;
1368
1369 let diff_body = language::unified_diff_with_offsets(
1370 old_text,
1371 model_output,
1372 editable_start_row,
1373 editable_start_row,
1374 );
1375
1376 let mut patch = String::new();
1377 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1378 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1379 patch.push_str(&diff_body);
1380 patch
1381}
1382
1383pub(crate) fn get_column_indices(
1384 meta: &Option<SnowflakeResultSetMetaData>,
1385 names: &[&str],
1386) -> std::collections::HashMap<String, usize> {
1387 let mut indices = std::collections::HashMap::new();
1388 if let Some(meta) = meta {
1389 for (index, col) in meta.row_type.iter().enumerate() {
1390 for &name in names {
1391 if col.name.eq_ignore_ascii_case(name) {
1392 indices.insert(name.to_string(), index);
1393 }
1394 }
1395 }
1396 }
1397 indices
1398}