1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11use telemetry_events::EditPredictionRating;
12
13use zeta_prompt::ZetaPromptInput;
14
15use crate::example::Example;
16use crate::progress::{InfoStyle, Progress, Step};
17const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
18use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
19use std::fmt::Write as _;
20
/// Value of the response `code` field that the `*_examples_from_response` helpers accept as
/// success; any other code present in a response makes them return an error.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Value of the response `code` field that `run_sql_with_polling` treats as "statement still
/// running", which makes it keep polling.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` value bound into the queries to select edit-prediction request events.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` value bound into the rejected-examples query to select rejection events.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// `event_type` value bound into the rated-examples query to select rating events.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
26
/// Lower bound on the Zed version of captured examples.
///
/// Only the minor and patch components are tracked. For instance,
/// `MinCaptureVersion { minor: 224, patch: 1 }` restricts pulls to examples
/// whose `zed_version` is at least `0.224.1`.
#[derive(Debug, Copy, Clone)]
pub struct MinCaptureVersion {
    /// Minimum minor version component (the `224` in `0.224.1`).
    pub minor: u32,
    /// Minimum patch version component (the `1` in `0.224.1`); it only applies when the
    /// minor component is exactly equal to `minor`.
    pub patch: u32,
}
35
/// Value sent as the `timeout` field of every statement request built in this file.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each status poll of a statement that is still running.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of status polls before `run_sql_with_polling` gives up with an error.
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
39
/// Extracts the timestamp from a token shaped like `captured-after:{timestamp}`.
///
/// Returns the text following the prefix (possibly empty), or `None` when `input`
/// does not start with `captured-after:`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    // The prefix is pure ASCII, so slicing at its length is always on a char boundary.
    input.starts_with(PREFIX).then(|| &input[PREFIX.len()..])
}
44
/// Extracts the timestamp from a token shaped like `rejected-after:{timestamp}`.
///
/// Returns the text following the prefix (possibly empty), or `None` when `input`
/// does not start with `rejected-after:`.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    // The keyword itself contains no ':', so splitting at the first colon isolates it exactly.
    match input.split_once(':') {
        Some(("rejected-after", timestamp)) => Some(timestamp),
        _ => None,
    }
}
49
/// Extracts the timestamp from a token shaped like `requested-after:{timestamp}`.
///
/// Returns the text following the prefix (possibly empty), or `None` when `input`
/// does not start with `requested-after:`.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    let after_keyword = input.strip_prefix("requested-after")?;
    after_keyword.strip_prefix(':')
}
54
55/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
56/// or `rated-negative-after:{timestamp}`.
57/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
58pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
59 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
60 Some((timestamp, Some(EditPredictionRating::Positive)))
61 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
62 Some((timestamp, Some(EditPredictionRating::Negative)))
63 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
64 Some((timestamp, None))
65 } else {
66 None
67 }
68}
69
70#[derive(Debug, Clone, Deserialize)]
71#[serde(rename_all = "camelCase")]
72pub(crate) struct SnowflakeStatementResponse {
73 #[serde(default)]
74 pub(crate) data: Vec<Vec<JsonValue>>,
75 #[serde(default)]
76 pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
77 #[serde(default)]
78 pub(crate) code: Option<String>,
79 #[serde(default)]
80 pub(crate) message: Option<String>,
81 #[serde(default)]
82 pub(crate) statement_handle: Option<String>,
83}
84
85#[derive(Debug, Clone, Deserialize)]
86#[serde(rename_all = "camelCase")]
87pub(crate) struct SnowflakeResultSetMetaData {
88 #[serde(default, rename = "rowType")]
89 row_type: Vec<SnowflakeColumnMeta>,
90 #[serde(default)]
91 num_rows: Option<i64>,
92 #[serde(default)]
93 partition_info: Vec<SnowflakePartitionInfo>,
94}
95
96#[derive(Debug, Clone, Deserialize)]
97#[serde(rename_all = "camelCase")]
98struct SnowflakePartitionInfo {}
99
100#[derive(Debug, Clone, Deserialize)]
101struct SnowflakeColumnMeta {
102 #[serde(default)]
103 name: String,
104}
105
106async fn run_sql_with_polling(
107 http_client: Arc<dyn HttpClient>,
108 base_url: &str,
109 token: &str,
110 request: &serde_json::Value,
111 step_progress: &crate::progress::StepProgress,
112 background_executor: BackgroundExecutor,
113) -> Result<SnowflakeStatementResponse> {
114 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
115
116 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
117 let statement_handle = response
118 .statement_handle
119 .as_ref()
120 .context("async query response missing statementHandle")?
121 .clone();
122
123 for attempt in 1..=MAX_POLL_ATTEMPTS {
124 step_progress.set_substatus(format!("polling ({attempt})"));
125
126 background_executor.timer(POLL_INTERVAL).await;
127
128 response =
129 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
130
131 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
132 break;
133 }
134 }
135
136 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
137 anyhow::bail!(
138 "query still running after {} poll attempts ({} seconds)",
139 MAX_POLL_ATTEMPTS,
140 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
141 );
142 }
143 }
144
145 Ok(response)
146}
147
148pub(crate) async fn fetch_partition(
149 http_client: Arc<dyn HttpClient>,
150 base_url: &str,
151 token: &str,
152 statement_handle: &str,
153 partition: usize,
154) -> Result<SnowflakeStatementResponse> {
155 let url = format!(
156 "{}/api/v2/statements/{}?partition={}",
157 base_url.trim_end_matches('/'),
158 statement_handle,
159 partition
160 );
161
162 let http_request = Request::builder()
163 .method(Method::GET)
164 .uri(url.as_str())
165 .header("Authorization", format!("Bearer {token}"))
166 .header(
167 "X-Snowflake-Authorization-Token-Type",
168 "PROGRAMMATIC_ACCESS_TOKEN",
169 )
170 .header("Accept", "application/json")
171 .header("Accept-Encoding", "gzip")
172 .header("User-Agent", "edit_prediction_cli")
173 .body(AsyncBody::empty())?;
174
175 let response = http_client
176 .send(http_request)
177 .await
178 .context("failed to send partition request to Snowflake SQL API")?;
179
180 let status = response.status();
181 let content_encoding = response
182 .headers()
183 .get("content-encoding")
184 .and_then(|v| v.to_str().ok())
185 .map(|s| s.to_lowercase());
186
187 let body_bytes = {
188 use futures::AsyncReadExt as _;
189
190 let mut body = response.into_body();
191 let mut bytes = Vec::new();
192 body.read_to_end(&mut bytes)
193 .await
194 .context("failed to read Snowflake SQL API partition response body")?;
195 bytes
196 };
197
198 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
199 let mut decoder = GzDecoder::new(&body_bytes[..]);
200 let mut decompressed = Vec::new();
201 decoder
202 .read_to_end(&mut decompressed)
203 .context("failed to decompress gzip response")?;
204 decompressed
205 } else {
206 body_bytes
207 };
208
209 if !status.is_success() && status.as_u16() != 202 {
210 let body_text = String::from_utf8_lossy(&body_bytes);
211 anyhow::bail!(
212 "snowflake sql api partition request http {}: {}",
213 status.as_u16(),
214 body_text
215 );
216 }
217
218 if body_bytes.is_empty() {
219 anyhow::bail!(
220 "snowflake sql api partition {} returned empty response body (http {})",
221 partition,
222 status.as_u16()
223 );
224 }
225
226 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
227 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
228 format!(
229 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
230 partition,
231 status.as_u16(),
232 body_preview
233 )
234 })
235}
236
237pub(crate) async fn run_sql(
238 http_client: Arc<dyn HttpClient>,
239 base_url: &str,
240 token: &str,
241 request: &serde_json::Value,
242) -> Result<SnowflakeStatementResponse> {
243 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
244
245 let request_body =
246 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
247
248 let http_request = Request::builder()
249 .method(Method::POST)
250 .uri(url.as_str())
251 .header("Authorization", format!("Bearer {token}"))
252 .header(
253 "X-Snowflake-Authorization-Token-Type",
254 "PROGRAMMATIC_ACCESS_TOKEN",
255 )
256 .header("Content-Type", "application/json")
257 .header("Accept", "application/json")
258 .header("User-Agent", "edit_prediction_cli")
259 .body(AsyncBody::from(request_body.clone()))?;
260
261 let response = http_client
262 .send(http_request)
263 .await
264 .context("failed to send request to Snowflake SQL API")?;
265
266 let status = response.status();
267 let body_bytes = {
268 use futures::AsyncReadExt as _;
269
270 let mut body = response.into_body();
271 let mut bytes = Vec::new();
272 body.read_to_end(&mut bytes)
273 .await
274 .context("failed to read Snowflake SQL API response body")?;
275 bytes
276 };
277
278 if !status.is_success() && status.as_u16() != 202 {
279 let body_text = String::from_utf8_lossy(&body_bytes);
280 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
281 }
282
283 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
284 .context("failed to parse Snowflake SQL API response JSON")
285}
286
287pub async fn fetch_rejected_examples_after(
288 http_client: Arc<dyn HttpClient>,
289 after_timestamps: &[String],
290 max_rows_per_timestamp: usize,
291 offset: usize,
292 background_executor: BackgroundExecutor,
293 min_capture_version: Option<MinCaptureVersion>,
294) -> Result<Vec<Example>> {
295 if after_timestamps.is_empty() {
296 return Ok(Vec::new());
297 }
298
299 let progress = Progress::global();
300
301 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
302 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
303 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
304 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
305 )?;
306 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
307
308 let mut all_examples = Vec::new();
309
310 for after_date in after_timestamps.iter() {
311 let step_progress_name = format!("rejected>{after_date}");
312 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
313 step_progress.set_substatus("querying");
314
315 // Join rejected events with their corresponding request events to get the full context.
316 // We filter for V3 sampling data which contains the structured input we need.
317 // We also filter for predictions that were actually shown to the user (was_shown = true)
318 // to focus on explicit user rejections rather than implicit cancellations.
319 let statement = indoc! {r#"
320 SELECT
321 req.event_properties:request_id::string AS request_id,
322 req.device_id::string AS device_id,
323 req.time::string AS time,
324 req.event_properties:input AS input,
325 req.event_properties:prompt::string AS prompt,
326 req.event_properties:output::string AS output,
327 rej.event_properties:was_shown::boolean AS was_shown,
328 rej.event_properties:reason::string AS reason,
329 req.event_properties:zed_version::string AS zed_version
330 FROM events req
331 INNER JOIN events rej
332 ON req.event_properties:request_id = rej.event_properties:request_id
333 WHERE req.event_type = ?
334 AND rej.event_type = ?
335 AND req.event_properties:version = 'V3'
336 AND rej.event_properties:was_shown = true
337 AND req.event_properties:input:can_collect_data = true
338 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
339 AND (? IS NULL OR (
340 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
341 OR (
342 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
343 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
344 )
345 ))
346 ORDER BY req.time ASC
347 LIMIT ?
348 OFFSET ?
349 "#};
350
351 let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
352 let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
353 let min_minor_str_ref = min_minor_str.as_deref();
354 let min_patch_str_ref = min_patch_str.as_deref();
355 let request = json!({
356 "statement": statement,
357 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
358 "database": "EVENTS",
359 "schema": "PUBLIC",
360 "warehouse": "DBT",
361 "role": role,
362 "bindings": {
363 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
364 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
365 "3": { "type": "TEXT", "value": after_date },
366 "4": { "type": "FIXED", "value": min_minor_str_ref },
367 "5": { "type": "FIXED", "value": min_minor_str_ref },
368 "6": { "type": "FIXED", "value": min_minor_str_ref },
369 "7": { "type": "FIXED", "value": min_patch_str_ref },
370 "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
371 "9": { "type": "FIXED", "value": offset.to_string() }
372 }
373 });
374
375 let response = run_sql_with_polling(
376 http_client.clone(),
377 &base_url,
378 &token,
379 &request,
380 &step_progress,
381 background_executor.clone(),
382 )
383 .await?;
384
385 let total_rows = response
386 .result_set_meta_data
387 .as_ref()
388 .and_then(|m| m.num_rows)
389 .unwrap_or(response.data.len() as i64);
390
391 let num_partitions = response
392 .result_set_meta_data
393 .as_ref()
394 .map(|m| m.partition_info.len())
395 .unwrap_or(1)
396 .max(1);
397
398 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
399 step_progress.set_substatus("parsing");
400
401 let column_indices = get_column_indices(
402 &response.result_set_meta_data,
403 &[
404 "request_id",
405 "device_id",
406 "time",
407 "input",
408 "prompt",
409 "output",
410 "was_shown",
411 "reason",
412 "zed_version",
413 ],
414 );
415
416 all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
417
418 if num_partitions > 1 {
419 let statement_handle = response
420 .statement_handle
421 .as_ref()
422 .context("response has multiple partitions but no statementHandle")?;
423
424 for partition in 1..num_partitions {
425 step_progress.set_substatus(format!(
426 "fetching partition {}/{}",
427 partition + 1,
428 num_partitions
429 ));
430
431 let partition_response = fetch_partition(
432 http_client.clone(),
433 &base_url,
434 &token,
435 statement_handle,
436 partition,
437 )
438 .await?;
439
440 all_examples.extend(rejected_examples_from_response(
441 &partition_response,
442 &column_indices,
443 )?);
444 }
445 }
446
447 step_progress.set_substatus("done");
448 }
449
450 Ok(all_examples)
451}
452
453pub async fn fetch_requested_examples_after(
454 http_client: Arc<dyn HttpClient>,
455 after_timestamps: &[String],
456 max_rows_per_timestamp: usize,
457 offset: usize,
458 background_executor: BackgroundExecutor,
459 min_capture_version: Option<MinCaptureVersion>,
460) -> Result<Vec<Example>> {
461 if after_timestamps.is_empty() {
462 return Ok(Vec::new());
463 }
464
465 let progress = Progress::global();
466
467 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
468 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
469 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
470 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
471 )?;
472 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
473
474 let mut all_examples = Vec::new();
475
476 for after_date in after_timestamps.iter() {
477 let step_progress_name = format!("requested>{after_date}");
478 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
479 step_progress.set_substatus("querying");
480
481 let statement = indoc! {r#"
482 SELECT
483 req.event_properties:request_id::string AS request_id,
484 req.device_id::string AS device_id,
485 req.time::string AS time,
486 req.event_properties:input AS input,
487 req.event_properties:zed_version::string AS zed_version
488 FROM events req
489 WHERE req.event_type = ?
490 AND req.event_properties:version = 'V3'
491 AND req.event_properties:input:can_collect_data = true
492 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
493 AND (? IS NULL OR (
494 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
495 OR (
496 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
497 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
498 )
499 ))
500 ORDER BY req.time ASC
501 LIMIT ?
502 OFFSET ?
503 "#};
504
505 let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
506 let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
507 let min_minor_str_ref = min_minor_str.as_deref();
508 let min_patch_str_ref = min_patch_str.as_deref();
509 let request = json!({
510 "statement": statement,
511 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
512 "database": "EVENTS",
513 "schema": "PUBLIC",
514 "warehouse": "DBT",
515 "role": role,
516 "bindings": {
517 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
518 "2": { "type": "TEXT", "value": after_date },
519 "3": { "type": "FIXED", "value": min_minor_str_ref },
520 "4": { "type": "FIXED", "value": min_minor_str_ref },
521 "5": { "type": "FIXED", "value": min_minor_str_ref },
522 "6": { "type": "FIXED", "value": min_patch_str_ref },
523 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
524 "8": { "type": "FIXED", "value": offset.to_string() }
525 }
526 });
527
528 let response = run_sql_with_polling(
529 http_client.clone(),
530 &base_url,
531 &token,
532 &request,
533 &step_progress,
534 background_executor.clone(),
535 )
536 .await?;
537
538 let total_rows = response
539 .result_set_meta_data
540 .as_ref()
541 .and_then(|m| m.num_rows)
542 .unwrap_or(response.data.len() as i64);
543
544 let num_partitions = response
545 .result_set_meta_data
546 .as_ref()
547 .map(|m| m.partition_info.len())
548 .unwrap_or(1)
549 .max(1);
550
551 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
552 step_progress.set_substatus("parsing");
553
554 let column_indices = get_column_indices(
555 &response.result_set_meta_data,
556 &["request_id", "device_id", "time", "input", "zed_version"],
557 );
558
559 all_examples.extend(requested_examples_from_response(
560 &response,
561 &column_indices,
562 )?);
563
564 if num_partitions > 1 {
565 let statement_handle = response
566 .statement_handle
567 .as_ref()
568 .context("response has multiple partitions but no statementHandle")?;
569
570 for partition in 1..num_partitions {
571 step_progress.set_substatus(format!(
572 "fetching partition {}/{}",
573 partition + 1,
574 num_partitions
575 ));
576
577 let partition_response = fetch_partition(
578 http_client.clone(),
579 &base_url,
580 &token,
581 statement_handle,
582 partition,
583 )
584 .await?;
585
586 all_examples.extend(requested_examples_from_response(
587 &partition_response,
588 &column_indices,
589 )?);
590 }
591 }
592
593 step_progress.set_substatus("done");
594 }
595
596 Ok(all_examples)
597}
598
599pub async fn fetch_rated_examples_after(
600 http_client: Arc<dyn HttpClient>,
601 inputs: &[(String, Option<EditPredictionRating>)],
602 max_rows_per_timestamp: usize,
603 offset: usize,
604 background_executor: BackgroundExecutor,
605 _min_capture_version: Option<MinCaptureVersion>,
606) -> Result<Vec<Example>> {
607 if inputs.is_empty() {
608 return Ok(Vec::new());
609 }
610
611 let progress = Progress::global();
612
613 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
614 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
615 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
616 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
617 )?;
618 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
619
620 let mut all_examples = Vec::new();
621
622 for (after_date, rating_filter) in inputs.iter() {
623 let filter_label = match rating_filter {
624 None => "",
625 Some(EditPredictionRating::Positive) => ":positive",
626 Some(EditPredictionRating::Negative) => ":negative",
627 };
628 let step_progress_name = format!("rated{filter_label}>{after_date}");
629 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
630 step_progress.set_substatus("querying");
631
632 let rating_value = rating_filter.as_ref().map(|r| match r {
633 EditPredictionRating::Positive => "Positive",
634 EditPredictionRating::Negative => "Negative",
635 });
636
637 let statement = indoc! {r#"
638 SELECT
639 rated.event_properties:request_id::string AS request_id,
640 rated.event_properties:inputs AS inputs,
641 rated.event_properties:output::string AS output,
642 rated.event_properties:rating::string AS rating,
643 rated.event_properties:feedback::string AS feedback,
644 rated.device_id::string AS device_id,
645 rated.time::string AS time,
646 deploy.event_properties:experiment_name::string AS experiment_name,
647 deploy.event_properties:environment::string AS environment,
648 rated.event_properties:zed_version::string AS zed_version
649 FROM events rated
650 LEFT JOIN events req
651 ON rated.event_properties:request_id::string = req.event_properties:request_id::string
652 AND req.event_type = ?
653 LEFT JOIN events deploy
654 ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
655 AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
656 AND deploy.event_type = ?
657 WHERE rated.event_type = ?
658 AND (? IS NULL OR rated.event_properties:rating::string = ?)
659 AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
660 AND rated.event_properties:inputs IS NOT NULL
661 AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
662 AND rated.event_properties:output IS NOT NULL
663 AND rated.event_properties:can_collect_data = true
664 ORDER BY rated.time ASC
665 LIMIT ?
666 OFFSET ?
667 "#};
668
669 let bindings = json!({
670 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
671 "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
672 "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
673 "4": { "type": "TEXT", "value": rating_value },
674 "5": { "type": "TEXT", "value": rating_value },
675 "6": { "type": "TEXT", "value": after_date },
676 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
677 "8": { "type": "FIXED", "value": offset.to_string() }
678 });
679
680 let request = json!({
681 "statement": statement,
682 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
683 "database": "EVENTS",
684 "schema": "PUBLIC",
685 "warehouse": "DBT",
686 "role": role,
687 "bindings": bindings
688 });
689
690 let response = run_sql_with_polling(
691 http_client.clone(),
692 &base_url,
693 &token,
694 &request,
695 &step_progress,
696 background_executor.clone(),
697 )
698 .await?;
699
700 let total_rows = response
701 .result_set_meta_data
702 .as_ref()
703 .and_then(|m| m.num_rows)
704 .unwrap_or(response.data.len() as i64);
705
706 let num_partitions = response
707 .result_set_meta_data
708 .as_ref()
709 .map(|m| m.partition_info.len())
710 .unwrap_or(1)
711 .max(1);
712
713 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
714 step_progress.set_substatus("parsing");
715
716 let column_indices = get_column_indices(
717 &response.result_set_meta_data,
718 &[
719 "request_id",
720 "inputs",
721 "output",
722 "rating",
723 "feedback",
724 "device_id",
725 "time",
726 "experiment_name",
727 "environment",
728 "zed_version",
729 ],
730 );
731
732 all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
733
734 if num_partitions > 1 {
735 let statement_handle = response
736 .statement_handle
737 .as_ref()
738 .context("response has multiple partitions but no statementHandle")?;
739
740 for partition in 1..num_partitions {
741 step_progress.set_substatus(format!(
742 "fetching partition {}/{}",
743 partition + 1,
744 num_partitions
745 ));
746
747 let partition_response = fetch_partition(
748 http_client.clone(),
749 &base_url,
750 &token,
751 statement_handle,
752 partition,
753 )
754 .await?;
755
756 all_examples.extend(rated_examples_from_response(
757 &partition_response,
758 &column_indices,
759 )?);
760 }
761 }
762
763 step_progress.set_substatus("done");
764 }
765
766 Ok(all_examples)
767}
768
769fn rated_examples_from_response<'a>(
770 response: &'a SnowflakeStatementResponse,
771 column_indices: &'a std::collections::HashMap<String, usize>,
772) -> Result<impl Iterator<Item = Example> + 'a> {
773 if let Some(code) = &response.code {
774 if code != SNOWFLAKE_SUCCESS_CODE {
775 anyhow::bail!(
776 "snowflake sql api returned error code={code} message={}",
777 response.message.as_deref().unwrap_or("<no message>")
778 );
779 }
780 }
781
782 let iter = response
783 .data
784 .iter()
785 .enumerate()
786 .filter_map(move |(row_index, data_row)| {
787 let get_string = |name: &str| -> Option<String> {
788 let index = column_indices.get(name).copied()?;
789 match data_row.get(index)? {
790 JsonValue::String(s) => Some(s.clone()),
791 JsonValue::Null => None,
792 other => Some(other.to_string()),
793 }
794 };
795
796 let get_json = |name: &str| -> Option<JsonValue> {
797 let index = column_indices.get(name).copied()?;
798 let value = data_row.get(index)?;
799 if value.is_null() {
800 return None;
801 }
802 match value {
803 JsonValue::String(s) => serde_json::from_str(s).ok(),
804 other => Some(other.clone()),
805 }
806 };
807
808 let request_id = get_string("request_id");
809 let inputs_json = get_json("inputs");
810 let inputs: Option<ZetaPromptInput> = match &inputs_json {
811 Some(v) => match serde_json::from_value(v.clone()) {
812 Ok(parsed) => Some(parsed),
813 Err(e) => {
814 log::warn!(
815 "skipping row {row_index}: failed to parse inputs - {e}",
816 );
817 return None;
818 }
819 },
820 None => None,
821 };
822 let output = get_string("output");
823 let rating = get_string("rating");
824 let feedback = get_string("feedback").unwrap_or_default();
825 let device_id = get_string("device_id");
826 let time = get_string("time");
827 let experiment_name = get_string("experiment_name");
828 let environment = get_string("environment");
829 let zed_version = get_string("zed_version");
830
831 match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
832 (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
833 Some(build_rated_example(
834 request_id,
835 device_id,
836 time,
837 inputs,
838 output,
839 rating,
840 feedback,
841 experiment_name,
842 environment,
843 zed_version,
844 ))
845 }
846 _ => {
847 log::warn!(
848 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
849 inputs_json.is_some(),
850 output.is_some(),
851 rating.is_some(),
852 device_id.is_some(),
853 time.is_some(),
854 );
855 None
856 }
857 }
858 });
859
860 Ok(iter)
861}
862
863fn build_rated_example(
864 request_id: Option<String>,
865 device_id: String,
866 time: String,
867 input: ZetaPromptInput,
868 output: String,
869 rating: String,
870 feedback: String,
871 experiment_name: Option<String>,
872 environment: Option<String>,
873 zed_version: Option<String>,
874) -> Example {
875 let parsed_rating = if rating == "Positive" {
876 EditPredictionRating::Positive
877 } else {
878 EditPredictionRating::Negative
879 };
880 let is_positive = parsed_rating == EditPredictionRating::Positive;
881 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
882
883 let mut tags = Vec::with_capacity(3);
884 tags.push(if is_positive {
885 "rated:positive".to_string()
886 } else {
887 "rated:negative".to_string()
888 });
889 if let Some(experiment) = experiment_name {
890 tags.push(format!("experiment:{experiment}"));
891 }
892 if let Some(env) = environment {
893 tags.push(format!("environment:{env}"));
894 }
895
896 let mut example =
897 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
898
899 example.spec.rating = Some(parsed_rating);
900
901 if !feedback.is_empty() {
902 example
903 .spec
904 .human_feedback
905 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
906 }
907
908 if is_positive {
909 example.spec.expected_patches = vec![output];
910 } else {
911 example.spec.rejected_patch = Some(output);
912 }
913
914 example
915}
916
917fn requested_examples_from_response<'a>(
918 response: &'a SnowflakeStatementResponse,
919 column_indices: &'a std::collections::HashMap<String, usize>,
920) -> Result<impl Iterator<Item = Example> + 'a> {
921 if let Some(code) = &response.code {
922 if code != SNOWFLAKE_SUCCESS_CODE {
923 anyhow::bail!(
924 "snowflake sql api returned error code={code} message={}",
925 response.message.as_deref().unwrap_or("<no message>")
926 );
927 }
928 }
929
930 let iter = response
931 .data
932 .iter()
933 .enumerate()
934 .filter_map(move |(row_index, data_row)| {
935 let get_string = |name: &str| -> Option<String> {
936 let index = column_indices.get(name).copied()?;
937 match data_row.get(index)? {
938 JsonValue::String(s) => Some(s.clone()),
939 JsonValue::Null => None,
940 other => Some(other.to_string()),
941 }
942 };
943
944 let get_json = |name: &str| -> Option<JsonValue> {
945 let index = column_indices.get(name).copied()?;
946 let value = data_row.get(index)?;
947 if value.is_null() {
948 return None;
949 }
950 match value {
951 JsonValue::String(s) => serde_json::from_str(s).ok(),
952 other => Some(other.clone()),
953 }
954 };
955
956 let request_id_str = get_string("request_id");
957 let device_id = get_string("device_id");
958 let time = get_string("time");
959 let input_json = get_json("input");
960 let input: Option<ZetaPromptInput> =
961 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
962 let zed_version = get_string("zed_version");
963
964 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
965 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
966 Some(build_example_from_snowflake(
967 request_id,
968 device_id,
969 time,
970 input,
971 vec!["requested".to_string()],
972 None,
973 zed_version,
974 ))
975 }
976 _ => {
977 log::warn!(
978 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
979 request_id_str.is_some(),
980 device_id.is_some(),
981 time.is_some(),
982 input_json.is_some(),
983 );
984 None
985 }
986 }
987 });
988
989 Ok(iter)
990}
991
992fn rejected_examples_from_response<'a>(
993 response: &'a SnowflakeStatementResponse,
994 column_indices: &'a std::collections::HashMap<String, usize>,
995) -> Result<impl Iterator<Item = Example> + 'a> {
996 if let Some(code) = &response.code {
997 if code != SNOWFLAKE_SUCCESS_CODE {
998 anyhow::bail!(
999 "snowflake sql api returned error code={code} message={}",
1000 response.message.as_deref().unwrap_or("<no message>")
1001 );
1002 }
1003 }
1004
1005 let iter = response
1006 .data
1007 .iter()
1008 .enumerate()
1009 .filter_map(move |(row_index, data_row)| {
1010 let get_string = |name: &str| -> Option<String> {
1011 let index = column_indices.get(name).copied()?;
1012 match data_row.get(index)? {
1013 JsonValue::String(s) => Some(s.clone()),
1014 JsonValue::Null => None,
1015 other => Some(other.to_string()),
1016 }
1017 };
1018
1019 let get_json = |name: &str| -> Option<JsonValue> {
1020 let index = column_indices.get(name).copied()?;
1021 let value = data_row.get(index)?;
1022 if value.is_null() {
1023 return None;
1024 }
1025 match value {
1026 JsonValue::String(s) => serde_json::from_str(s).ok(),
1027 other => Some(other.clone()),
1028 }
1029 };
1030
1031 let get_bool = |name: &str| -> Option<bool> {
1032 let index = column_indices.get(name).copied()?;
1033 match data_row.get(index)? {
1034 JsonValue::Bool(b) => Some(*b),
1035 JsonValue::String(s) => s.parse().ok(),
1036 _ => None,
1037 }
1038 };
1039
1040 let request_id_str = get_string("request_id");
1041 let device_id = get_string("device_id");
1042 let time = get_string("time");
1043 let input_json = get_json("input");
1044 let input: Option<ZetaPromptInput> =
1045 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1046 let output = get_string("output");
1047 let was_shown = get_bool("was_shown");
1048 let reason = get_string("reason");
1049 let zed_version = get_string("zed_version");
1050
1051 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1052 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1053 Some(build_rejected_example(
1054 request_id,
1055 device_id,
1056 time,
1057 input,
1058 output,
1059 was_shown,
1060 reason,
1061 zed_version,
1062 ))
1063 }
1064 _ => {
1065 log::warn!(
1066 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1067 request_id_str.is_some(),
1068 device_id.is_some(),
1069 time.is_some(),
1070 input_json.is_some(),
1071 output.is_some(),
1072 was_shown.is_some(),
1073 reason.is_some()
1074 );
1075 None
1076 }
1077 }
1078 });
1079
1080 Ok(iter)
1081}
1082
1083fn build_rejected_example(
1084 request_id: String,
1085 device_id: String,
1086 time: String,
1087 input: ZetaPromptInput,
1088 output: String,
1089 was_shown: bool,
1090 reason: String,
1091 zed_version: Option<String>,
1092) -> Example {
1093 let rejected_patch = build_output_patch(
1094 &input.cursor_path,
1095 input.cursor_excerpt.as_ref(),
1096 &input.editable_range_in_excerpt,
1097 &output,
1098 );
1099 let mut example = build_example_from_snowflake(
1100 request_id,
1101 device_id,
1102 time,
1103 input,
1104 vec![format!("rejection:{}", reason.to_lowercase())],
1105 Some(RejectionInfo { reason, was_shown }),
1106 zed_version,
1107 );
1108 example.spec.rejected_patch = Some(rejected_patch);
1109 example
1110}
1111
/// Rejection details attached to an example built from a rejected prediction.
struct RejectionInfo {
    /// Reason the prediction was rejected; copied into the example's telemetry as
    /// `rejection_reason`.
    reason: String,
    /// Copied into the example's telemetry as `was_shown`; presumably whether the
    /// prediction was displayed to the user before it was rejected (confirm against
    /// the telemetry schema).
    was_shown: bool,
}
1116
1117fn build_example_from_snowflake(
1118 request_id: String,
1119 device_id: String,
1120 time: String,
1121 input: ZetaPromptInput,
1122 tags: Vec<String>,
1123 rejection: Option<RejectionInfo>,
1124 zed_version: Option<String>,
1125) -> Example {
1126 let cursor_excerpt = input.cursor_excerpt.as_ref();
1127 let cursor_offset = input.cursor_offset_in_excerpt;
1128
1129 let mut edit_history = String::new();
1130 for event in &input.events {
1131 zeta_prompt::write_event(&mut edit_history, event);
1132 edit_history.push('\n');
1133 }
1134
1135 let (rejection_reason, was_shown) = match &rejection {
1136 Some(r) => (r.reason.clone(), r.was_shown),
1137 None => (String::new(), false),
1138 };
1139
1140 let spec = ExampleSpec {
1141 name: request_id.clone(),
1142 repository_url: String::new(),
1143 revision: String::new(),
1144 tags,
1145 reasoning: None,
1146 uncommitted_diff: String::new(),
1147 cursor_path: input.cursor_path.clone(),
1148 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1149 edit_history,
1150 expected_patches: Vec::new(),
1151 rejected_patch: None,
1152 telemetry: Some(TelemetrySource {
1153 request_id,
1154 device_id,
1155 time,
1156 rejection_reason,
1157 was_shown,
1158 }),
1159 human_feedback: Vec::new(),
1160 rating: None,
1161 };
1162
1163 Example {
1164 spec,
1165 zed_version,
1166 prompt_inputs: Some(input),
1167 prompt: None,
1168 predictions: Vec::new(),
1169 score: Vec::new(),
1170 qa: Vec::new(),
1171 state: None,
1172 }
1173}
1174
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset into `excerpt`. Offsets past the end are clamped
/// to the end of the excerpt. An offset that falls inside a multi-byte UTF-8 sequence
/// is moved back to the start of that character, so the function never panics on
/// offsets coming from telemetry.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1180
1181fn build_output_patch(
1182 cursor_path: &std::path::Path,
1183 cursor_excerpt: &str,
1184 editable_range: &std::ops::Range<usize>,
1185 model_output: &str,
1186) -> String {
1187 let old_text = &cursor_excerpt[editable_range.clone()];
1188
1189 let editable_start_row = cursor_excerpt[..editable_range.start]
1190 .chars()
1191 .filter(|&c| c == '\n')
1192 .count() as u32;
1193
1194 let diff_body = language::unified_diff_with_offsets(
1195 old_text,
1196 model_output,
1197 editable_start_row,
1198 editable_start_row,
1199 );
1200
1201 let mut patch = String::new();
1202 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1203 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1204 patch.push_str(&diff_body);
1205 patch
1206}
1207
1208pub(crate) fn get_column_indices(
1209 meta: &Option<SnowflakeResultSetMetaData>,
1210 names: &[&str],
1211) -> std::collections::HashMap<String, usize> {
1212 let mut indices = std::collections::HashMap::new();
1213 if let Some(meta) = meta {
1214 for (index, col) in meta.row_type.iter().enumerate() {
1215 for &name in names {
1216 if col.name.eq_ignore_ascii_case(name) {
1217 indices.insert(name.to_string(), index);
1218 }
1219 }
1220 }
1221 }
1222 indices
1223}