1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::fmt::Write as _;
9use std::io::Read;
10use std::sync::Arc;
11use std::time::Duration;
12use telemetry_events::EditPredictionRating;
13
14use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
15
16use crate::example::Example;
17use crate::progress::{InfoStyle, Progress, Step};
/// Event type of deployment records; bound into the rated-examples query to join
/// deployment events onto the originating request.
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
19use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
20
/// Response `code` that the response parsers in this file accept as success; any other
/// present code is reported as an error.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` meaning the statement is still executing; seeing it triggers polling.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
// Telemetry event type names bound as parameters into the SQL queries below.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
27
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
/// where `zed_version >= 0.224.1`.
///
/// The queries in this file compare these values against the second and third
/// dot-separated components of the `zed_version` string (anything after a `+` in the
/// third component is ignored).
#[derive(Clone, Copy, Debug)]
pub struct MinCaptureVersion {
    /// Minimum value of the second version component (`224` in `0.224.1`).
    pub minor: u32,
    /// Minimum value of the third version component, applied only when the minor
    /// component is exactly equal to `minor`.
    pub patch: u32,
}
36
/// Value sent as the `timeout` field of the statement request for most queries.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Value sent as the `timeout` field of the statement request for the settled-examples query.
const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Delay between consecutive polls of a statement that is still executing.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls before giving up on a statement that is still executing
/// (120 polls at 2 seconds each is 240 seconds in total).
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
41
/// Extracts the timestamp from a `captured-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `captured-after:` prefix.
/// The remainder is returned verbatim and may be empty.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
46
/// Extracts the timestamp from a `rejected-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `rejected-after:` prefix.
/// The remainder is returned verbatim and may be empty.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
51
/// Extracts the timestamp from a `requested-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `requested-after:` prefix.
/// The remainder is returned verbatim and may be empty.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
56
/// Extracts the timestamp from a `settled-after:{timestamp}` input token.
///
/// Returns `None` when `input` does not start with the `settled-after:` prefix.
/// The remainder is returned verbatim and may be empty.
pub fn parse_settled_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "settled-after:";
    let timestamp = input.strip_prefix(PREFIX)?;
    Some(timestamp)
}
61
62/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
63/// or `rated-negative-after:{timestamp}`.
64/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
65pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
66 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
67 Some((timestamp, Some(EditPredictionRating::Positive)))
68 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
69 Some((timestamp, Some(EditPredictionRating::Negative)))
70 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
71 Some((timestamp, None))
72 } else {
73 None
74 }
75}
76
/// The parts of a Snowflake SQL API statement response that this module reads.
///
/// Field names are mapped from camelCase JSON keys, and every field falls back to its
/// default value when the key is absent.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON cell values addressed by column index.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Result-set metadata (columns, row count, partitions), when present.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE` elsewhere in this file.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Message accompanying `code`; included in error reports.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll the statement and to fetch additional result partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
91
/// Metadata describing a statement's result set.
///
/// Every field falls back to its default value when absent from the JSON payload.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeResultSetMetaData {
    /// Column descriptions, deserialized from the `rowType` key.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported by the server; used for progress reporting and as a
    /// capacity hint when collecting parsed examples.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used in this file.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
102
/// Marker for a single result partition entry.
///
/// No fields are read from it; the code visible here only counts how many entries
/// the server reported.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
106
/// Description of one result column, as listed under `rowType`.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; defaults to an empty string when the key is missing.
    // NOTE(review): presumably matched against the required column names by
    // `get_column_indices` — confirm against that helper.
    #[serde(default)]
    name: String,
}
112
113async fn run_sql_with_polling(
114 http_client: Arc<dyn HttpClient>,
115 base_url: &str,
116 token: &str,
117 request: &serde_json::Value,
118 step_progress: &crate::progress::StepProgress,
119 background_executor: BackgroundExecutor,
120) -> Result<SnowflakeStatementResponse> {
121 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
122
123 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
124 let statement_handle = response
125 .statement_handle
126 .as_ref()
127 .context("async query response missing statementHandle")?
128 .clone();
129
130 for attempt in 1..=MAX_POLL_ATTEMPTS {
131 step_progress.set_substatus(format!("polling ({attempt})"));
132
133 background_executor.timer(POLL_INTERVAL).await;
134
135 response =
136 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
137
138 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
139 break;
140 }
141 }
142
143 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
144 anyhow::bail!(
145 "query still running after {} poll attempts ({} seconds)",
146 MAX_POLL_ATTEMPTS,
147 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
148 );
149 }
150 }
151
152 Ok(response)
153}
154
/// Connection settings for the Snowflake SQL API, populated from environment variables.
struct SnowflakeConfig {
    /// Access token sent as a bearer token (from `EP_SNOWFLAKE_API_KEY`).
    token: String,
    /// Base URL of the Snowflake account (from `EP_SNOWFLAKE_BASE_URL`).
    base_url: String,
    /// Optional role to run statements under (from `EP_SNOWFLAKE_ROLE`).
    role: Option<String>,
}
160
161async fn fetch_examples_with_query(
162 http_client: Arc<dyn HttpClient>,
163 step_progress: &crate::progress::StepProgress,
164 background_executor: BackgroundExecutor,
165 statement: &str,
166 bindings: JsonValue,
167 timeout_seconds: u64,
168 required_columns: &[&str],
169 parse_response: for<'a> fn(
170 &'a SnowflakeStatementResponse,
171 &'a std::collections::HashMap<String, usize>,
172 ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
173) -> Result<Vec<Example>> {
174 let snowflake = SnowflakeConfig {
175 token: std::env::var("EP_SNOWFLAKE_API_KEY")
176 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
177 base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context(
178 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
179 )?,
180 role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
181 };
182 let request = json!({
183 "statement": statement,
184 "timeout": timeout_seconds,
185 "database": "EVENTS",
186 "schema": "PUBLIC",
187 "warehouse": "DBT",
188 "role": snowflake.role.as_deref(),
189 "bindings": bindings
190 });
191
192 let response = run_sql_with_polling(
193 http_client.clone(),
194 &snowflake.base_url,
195 &snowflake.token,
196 &request,
197 step_progress,
198 background_executor,
199 )
200 .await?;
201
202 let total_rows = response
203 .result_set_meta_data
204 .as_ref()
205 .and_then(|meta| meta.num_rows)
206 .unwrap_or(response.data.len() as i64);
207 let partition_count = response
208 .result_set_meta_data
209 .as_ref()
210 .map(|meta| meta.partition_info.len())
211 .unwrap_or(1)
212 .max(1);
213
214 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
215 step_progress.set_substatus("parsing");
216
217 let column_indices = get_column_indices(&response.result_set_meta_data, required_columns);
218
219 let mut parsed_examples = Vec::with_capacity(total_rows as usize);
220 parsed_examples.extend(parse_response(&response, &column_indices)?);
221
222 if partition_count > 1 {
223 let statement_handle = response
224 .statement_handle
225 .as_ref()
226 .context("response has multiple partitions but no statementHandle")?;
227
228 for partition in 1..partition_count {
229 step_progress.set_substatus(format!(
230 "fetching partition {}/{}",
231 partition + 1,
232 partition_count
233 ));
234
235 let partition_response = fetch_partition(
236 http_client.clone(),
237 &snowflake.base_url,
238 &snowflake.token,
239 statement_handle,
240 partition,
241 )
242 .await?;
243
244 parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
245 }
246 }
247
248 step_progress.set_substatus("done");
249 Ok(parsed_examples)
250}
251
252pub(crate) async fn fetch_partition(
253 http_client: Arc<dyn HttpClient>,
254 base_url: &str,
255 token: &str,
256 statement_handle: &str,
257 partition: usize,
258) -> Result<SnowflakeStatementResponse> {
259 let url = format!(
260 "{}/api/v2/statements/{}?partition={}",
261 base_url.trim_end_matches('/'),
262 statement_handle,
263 partition
264 );
265
266 let http_request = Request::builder()
267 .method(Method::GET)
268 .uri(url.as_str())
269 .header("Authorization", format!("Bearer {token}"))
270 .header(
271 "X-Snowflake-Authorization-Token-Type",
272 "PROGRAMMATIC_ACCESS_TOKEN",
273 )
274 .header("Accept", "application/json")
275 .header("Accept-Encoding", "gzip")
276 .header("User-Agent", "edit_prediction_cli")
277 .body(AsyncBody::empty())?;
278
279 let response = http_client
280 .send(http_request)
281 .await
282 .context("failed to send partition request to Snowflake SQL API")?;
283
284 let status = response.status();
285 let content_encoding = response
286 .headers()
287 .get("content-encoding")
288 .and_then(|v| v.to_str().ok())
289 .map(|s| s.to_lowercase());
290
291 let body_bytes = {
292 use futures::AsyncReadExt as _;
293
294 let mut body = response.into_body();
295 let mut bytes = Vec::new();
296 body.read_to_end(&mut bytes)
297 .await
298 .context("failed to read Snowflake SQL API partition response body")?;
299 bytes
300 };
301
302 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
303 let mut decoder = GzDecoder::new(&body_bytes[..]);
304 let mut decompressed = Vec::new();
305 decoder
306 .read_to_end(&mut decompressed)
307 .context("failed to decompress gzip response")?;
308 decompressed
309 } else {
310 body_bytes
311 };
312
313 if !status.is_success() && status.as_u16() != 202 {
314 let body_text = String::from_utf8_lossy(&body_bytes);
315 anyhow::bail!(
316 "snowflake sql api partition request http {}: {}",
317 status.as_u16(),
318 body_text
319 );
320 }
321
322 if body_bytes.is_empty() {
323 anyhow::bail!(
324 "snowflake sql api partition {} returned empty response body (http {})",
325 partition,
326 status.as_u16()
327 );
328 }
329
330 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
331 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
332 format!(
333 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
334 partition,
335 status.as_u16(),
336 body_preview
337 )
338 })
339}
340
341pub(crate) async fn run_sql(
342 http_client: Arc<dyn HttpClient>,
343 base_url: &str,
344 token: &str,
345 request: &serde_json::Value,
346) -> Result<SnowflakeStatementResponse> {
347 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
348
349 let request_body =
350 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
351
352 let http_request = Request::builder()
353 .method(Method::POST)
354 .uri(url.as_str())
355 .header("Authorization", format!("Bearer {token}"))
356 .header(
357 "X-Snowflake-Authorization-Token-Type",
358 "PROGRAMMATIC_ACCESS_TOKEN",
359 )
360 .header("Content-Type", "application/json")
361 .header("Accept", "application/json")
362 .header("User-Agent", "edit_prediction_cli")
363 .body(AsyncBody::from(request_body.clone()))?;
364
365 let response = http_client
366 .send(http_request)
367 .await
368 .context("failed to send request to Snowflake SQL API")?;
369
370 let status = response.status();
371 let body_bytes = {
372 use futures::AsyncReadExt as _;
373
374 let mut body = response.into_body();
375 let mut bytes = Vec::new();
376 body.read_to_end(&mut bytes)
377 .await
378 .context("failed to read Snowflake SQL API response body")?;
379 bytes
380 };
381
382 if !status.is_success() && status.as_u16() != 202 {
383 let body_text = String::from_utf8_lossy(&body_bytes);
384 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
385 }
386
387 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
388 .context("failed to parse Snowflake SQL API response JSON")
389}
390
391pub async fn fetch_rejected_examples_after(
392 http_client: Arc<dyn HttpClient>,
393 after_timestamps: &[String],
394 max_rows_per_timestamp: usize,
395 offset: usize,
396 background_executor: BackgroundExecutor,
397 min_capture_version: Option<MinCaptureVersion>,
398) -> Result<Vec<Example>> {
399 if after_timestamps.is_empty() {
400 return Ok(Vec::new());
401 }
402
403 let progress = Progress::global();
404
405 let mut all_examples = Vec::new();
406
407 for after_date in after_timestamps.iter() {
408 let step_progress_name = format!("rejected>{after_date}");
409 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
410 step_progress.set_substatus("querying");
411
412 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
413 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
414 let min_minor_str_ref = min_minor_str.as_deref();
415 let min_patch_str_ref = min_patch_str.as_deref();
416
417 let statement = indoc! {r#"
418 SELECT
419 req.event_properties:request_id::string AS request_id,
420 req.device_id::string AS device_id,
421 req.time::string AS time,
422 req.event_properties:input AS input,
423 req.event_properties:prompt::string AS prompt,
424 req.event_properties:output::string AS output,
425 rej.event_properties:was_shown::boolean AS was_shown,
426 rej.event_properties:reason::string AS reason,
427 req.event_properties:zed_version::string AS zed_version
428 FROM events req
429 INNER JOIN events rej
430 ON req.event_properties:request_id = rej.event_properties:request_id
431 WHERE req.event_type = ?
432 AND rej.event_type = ?
433 AND req.event_properties:version = 'V3'
434 AND rej.event_properties:was_shown = true
435 AND req.event_properties:input:can_collect_data = true
436 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
437 AND (? IS NULL OR (
438 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
439 OR (
440 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
441 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
442 )
443 ))
444 ORDER BY req.time ASC
445 LIMIT ?
446 OFFSET ?
447 "#};
448
449 let bindings = json!({
450 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
451 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
452 "3": { "type": "TEXT", "value": after_date },
453 "4": { "type": "FIXED", "value": min_minor_str_ref },
454 "5": { "type": "FIXED", "value": min_minor_str_ref },
455 "6": { "type": "FIXED", "value": min_minor_str_ref },
456 "7": { "type": "FIXED", "value": min_patch_str_ref },
457 "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
458 "9": { "type": "FIXED", "value": offset.to_string() }
459 });
460
461 let examples = fetch_examples_with_query(
462 http_client.clone(),
463 &step_progress,
464 background_executor.clone(),
465 statement,
466 bindings,
467 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
468 &[
469 "request_id",
470 "device_id",
471 "time",
472 "input",
473 "prompt",
474 "output",
475 "was_shown",
476 "reason",
477 "zed_version",
478 ],
479 rejected_examples_from_response,
480 )
481 .await?;
482
483 all_examples.extend(examples);
484 }
485
486 Ok(all_examples)
487}
488
489pub async fn fetch_requested_examples_after(
490 http_client: Arc<dyn HttpClient>,
491 after_timestamps: &[String],
492 max_rows_per_timestamp: usize,
493 offset: usize,
494 background_executor: BackgroundExecutor,
495 min_capture_version: Option<MinCaptureVersion>,
496) -> Result<Vec<Example>> {
497 if after_timestamps.is_empty() {
498 return Ok(Vec::new());
499 }
500
501 let progress = Progress::global();
502
503 let mut all_examples = Vec::new();
504
505 for after_date in after_timestamps.iter() {
506 let step_progress_name = format!("requested>{after_date}");
507 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
508 step_progress.set_substatus("querying");
509
510 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
511 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
512 let min_minor_str_ref = min_minor_str.as_deref();
513 let min_patch_str_ref = min_patch_str.as_deref();
514
515 let statement = indoc! {r#"
516 SELECT
517 req.event_properties:request_id::string AS request_id,
518 req.device_id::string AS device_id,
519 req.time::string AS time,
520 req.event_properties:input AS input,
521 req.event_properties:zed_version::string AS zed_version
522 FROM events req
523 WHERE req.event_type = ?
524 AND req.event_properties:version = 'V3'
525 AND req.event_properties:input:can_collect_data = true
526 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
527 AND (? IS NULL OR (
528 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
529 OR (
530 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
531 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
532 )
533 ))
534 ORDER BY req.time ASC
535 LIMIT ?
536 OFFSET ?
537 "#};
538
539 let bindings = json!({
540 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
541 "2": { "type": "TEXT", "value": after_date },
542 "3": { "type": "FIXED", "value": min_minor_str_ref },
543 "4": { "type": "FIXED", "value": min_minor_str_ref },
544 "5": { "type": "FIXED", "value": min_minor_str_ref },
545 "6": { "type": "FIXED", "value": min_patch_str_ref },
546 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
547 "8": { "type": "FIXED", "value": offset.to_string() }
548 });
549
550 let examples = fetch_examples_with_query(
551 http_client.clone(),
552 &step_progress,
553 background_executor.clone(),
554 statement,
555 bindings,
556 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
557 &["request_id", "device_id", "time", "input", "zed_version"],
558 requested_examples_from_response,
559 )
560 .await?;
561
562 all_examples.extend(examples);
563 }
564
565 Ok(all_examples)
566}
567
568pub async fn fetch_captured_examples_after(
569 http_client: Arc<dyn HttpClient>,
570 after_timestamps: &[String],
571 max_rows_per_timestamp: usize,
572 offset: usize,
573 background_executor: BackgroundExecutor,
574 min_capture_version: Option<MinCaptureVersion>,
575) -> Result<Vec<Example>> {
576 if after_timestamps.is_empty() {
577 return Ok(Vec::new());
578 }
579
580 let progress = Progress::global();
581
582 let mut all_examples = Vec::new();
583
584 for after_date in after_timestamps.iter() {
585 let step_progress_name = format!("captured>{after_date}");
586 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
587 step_progress.set_substatus("querying");
588
589 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
590 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
591 let min_minor_str_ref = min_minor_str.as_deref();
592 let min_patch_str_ref = min_patch_str.as_deref();
593
594 let statement = indoc! {r#"
595 SELECT
596 settled.event_properties:request_id::string AS request_id,
597 settled.device_id::string AS device_id,
598 settled.time::string AS time,
599 req.event_properties:input AS input,
600 settled.event_properties:settled_editable_region::string AS settled_editable_region,
601 settled.event_properties:example AS example,
602 req.event_properties:zed_version::string AS zed_version
603 FROM events settled
604 INNER JOIN events req
605 ON settled.event_properties:request_id::string = req.event_properties:request_id::string
606 WHERE settled.event_type = ?
607 AND req.event_type = ?
608 AND req.event_properties:version = 'V3'
609 AND req.event_properties:input:can_collect_data = true
610 AND settled.event_properties:example IS NOT NULL
611 AND TYPEOF(settled.event_properties:example) != 'NULL_VALUE'
612 AND settled.time > TRY_TO_TIMESTAMP_NTZ(?)
613 AND (? IS NULL OR (
614 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
615 OR (
616 TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
617 AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
618 )
619 ))
620 ORDER BY settled.time ASC
621 LIMIT ?
622 OFFSET ?
623 "#};
624
625 let bindings = json!({
626 "1": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
627 "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
628 "3": { "type": "TEXT", "value": after_date },
629 "4": { "type": "FIXED", "value": min_minor_str_ref },
630 "5": { "type": "FIXED", "value": min_minor_str_ref },
631 "6": { "type": "FIXED", "value": min_minor_str_ref },
632 "7": { "type": "FIXED", "value": min_patch_str_ref },
633 "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
634 "9": { "type": "FIXED", "value": offset.to_string() }
635 });
636
637 let examples = fetch_examples_with_query(
638 http_client.clone(),
639 &step_progress,
640 background_executor.clone(),
641 statement,
642 bindings,
643 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
644 &[
645 "request_id",
646 "device_id",
647 "time",
648 "input",
649 "settled_editable_region",
650 "example",
651 "zed_version",
652 ],
653 captured_examples_from_response,
654 )
655 .await?;
656
657 all_examples.extend(examples);
658 }
659
660 Ok(all_examples)
661}
662
663pub async fn fetch_settled_examples_after(
664 http_client: Arc<dyn HttpClient>,
665 after_timestamps: &[String],
666 max_rows_per_timestamp: usize,
667 offset: usize,
668 background_executor: BackgroundExecutor,
669 min_capture_version: Option<MinCaptureVersion>,
670) -> Result<Vec<Example>> {
671 if after_timestamps.is_empty() {
672 return Ok(Vec::new());
673 }
674
675 let progress = Progress::global();
676
677 let mut all_examples = Vec::new();
678
679 for after_date in after_timestamps.iter() {
680 let step_progress_name = format!("settled>{after_date}");
681 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
682 step_progress.set_substatus("querying");
683
684 let _ = min_capture_version;
685
686 let statement = indoc! {r#"
687 WITH requested AS (
688 SELECT
689 req.event_properties:request_id::string AS request_id,
690 req.device_id::string AS device_id,
691 req.time AS req_time,
692 req.time::string AS time,
693 req.event_properties:input AS input,
694 req.event_properties:format::string AS requested_format,
695 req.event_properties:output::string AS requested_output,
696 req.event_properties:zed_version::string AS zed_version
697 FROM events req
698 WHERE req.event_type = ?
699 AND req.event_properties:version = 'V3'
700 AND req.event_properties:input:can_collect_data = true
701 AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
702 )
703 SELECT
704 req.request_id AS request_id,
705 req.device_id AS device_id,
706 req.time AS time,
707 req.input AS input,
708 req.requested_output AS requested_output,
709 settled.event_properties:settled_editable_region::string AS settled_editable_region,
710 req.requested_format AS requested_format,
711 req.zed_version AS zed_version
712 FROM requested req
713 INNER JOIN events settled
714 ON req.request_id = settled.event_properties:request_id::string
715 WHERE settled.event_type = ?
716 ORDER BY req.req_time ASC
717 LIMIT ?
718 OFFSET ?
719 "#};
720
721 let bindings = json!({
722 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
723 "2": { "type": "TEXT", "value": after_date },
724 "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
725 "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
726 "5": { "type": "FIXED", "value": offset.to_string() }
727 });
728
729 let examples = fetch_examples_with_query(
730 http_client.clone(),
731 &step_progress,
732 background_executor.clone(),
733 statement,
734 bindings,
735 SETTLED_STATEMENT_TIMEOUT_SECONDS,
736 &[
737 "request_id",
738 "device_id",
739 "time",
740 "input",
741 "requested_output",
742 "settled_editable_region",
743 "requested_format",
744 "zed_version",
745 ],
746 settled_examples_from_response,
747 )
748 .await?;
749
750 all_examples.extend(examples);
751 }
752
753 Ok(all_examples)
754}
755
756pub async fn fetch_rated_examples_after(
757 http_client: Arc<dyn HttpClient>,
758 inputs: &[(String, Option<EditPredictionRating>)],
759 max_rows_per_timestamp: usize,
760 offset: usize,
761 background_executor: BackgroundExecutor,
762 _min_capture_version: Option<MinCaptureVersion>,
763) -> Result<Vec<Example>> {
764 if inputs.is_empty() {
765 return Ok(Vec::new());
766 }
767
768 let progress = Progress::global();
769
770 let mut all_examples = Vec::new();
771
772 for (after_date, rating_filter) in inputs.iter() {
773 let filter_label = match rating_filter {
774 None => "",
775 Some(EditPredictionRating::Positive) => ":positive",
776 Some(EditPredictionRating::Negative) => ":negative",
777 };
778 let step_progress_name = format!("rated{filter_label}>{after_date}");
779 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
780 step_progress.set_substatus("querying");
781
782 let rating_value = rating_filter.as_ref().map(|rating| match rating {
783 EditPredictionRating::Positive => "Positive",
784 EditPredictionRating::Negative => "Negative",
785 });
786
787 let statement = indoc! {r#"
788 SELECT
789 rated.event_properties:request_id::string AS request_id,
790 rated.event_properties:inputs AS inputs,
791 rated.event_properties:output::string AS output,
792 rated.event_properties:rating::string AS rating,
793 rated.event_properties:feedback::string AS feedback,
794 rated.device_id::string AS device_id,
795 rated.time::string AS time,
796 deploy.event_properties:experiment_name::string AS experiment_name,
797 deploy.event_properties:environment::string AS environment,
798 rated.event_properties:zed_version::string AS zed_version
799 FROM events rated
800 LEFT JOIN events req
801 ON rated.event_properties:request_id::string = req.event_properties:request_id::string
802 AND req.event_type = ?
803 LEFT JOIN events deploy
804 ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
805 AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
806 AND deploy.event_type = ?
807 WHERE rated.event_type = ?
808 AND (? IS NULL OR rated.event_properties:rating::string = ?)
809 AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
810 AND rated.event_properties:inputs IS NOT NULL
811 AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
812 AND rated.event_properties:output IS NOT NULL
813 AND rated.event_properties:inputs:can_collect_data = true
814 ORDER BY rated.time ASC
815 LIMIT ?
816 OFFSET ?
817 "#};
818
819 let bindings = json!({
820 "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
821 "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
822 "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
823 "4": { "type": "TEXT", "value": rating_value },
824 "5": { "type": "TEXT", "value": rating_value },
825 "6": { "type": "TEXT", "value": after_date },
826 "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
827 "8": { "type": "FIXED", "value": offset.to_string() }
828 });
829
830 let examples = fetch_examples_with_query(
831 http_client.clone(),
832 &step_progress,
833 background_executor.clone(),
834 statement,
835 bindings,
836 DEFAULT_STATEMENT_TIMEOUT_SECONDS,
837 &[
838 "request_id",
839 "inputs",
840 "output",
841 "rating",
842 "feedback",
843 "device_id",
844 "time",
845 "experiment_name",
846 "environment",
847 "zed_version",
848 ],
849 rated_examples_from_response,
850 )
851 .await?;
852
853 all_examples.extend(examples);
854 }
855
856 Ok(all_examples)
857}
858
859fn rated_examples_from_response<'a>(
860 response: &'a SnowflakeStatementResponse,
861 column_indices: &'a std::collections::HashMap<String, usize>,
862) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
863 if let Some(code) = &response.code {
864 if code != SNOWFLAKE_SUCCESS_CODE {
865 anyhow::bail!(
866 "snowflake sql api returned error code={code} message={}",
867 response.message.as_deref().unwrap_or("<no message>")
868 );
869 }
870 }
871
872 let iter = response
873 .data
874 .iter()
875 .enumerate()
876 .filter_map(move |(row_index, data_row)| {
877 let get_string = |name: &str| -> Option<String> {
878 let index = column_indices.get(name).copied()?;
879 match data_row.get(index)? {
880 JsonValue::String(s) => Some(s.clone()),
881 JsonValue::Null => None,
882 other => Some(other.to_string()),
883 }
884 };
885
886 let get_json = |name: &str| -> Option<JsonValue> {
887 let index = column_indices.get(name).copied()?;
888 let value = data_row.get(index)?;
889 if value.is_null() {
890 return None;
891 }
892 match value {
893 JsonValue::String(s) => serde_json::from_str(s).ok(),
894 other => Some(other.clone()),
895 }
896 };
897
898 let request_id = get_string("request_id");
899 let inputs_json = get_json("inputs");
900 let inputs: Option<ZetaPromptInput> = match &inputs_json {
901 Some(v) => match serde_json::from_value(v.clone()) {
902 Ok(parsed) => Some(parsed),
903 Err(e) => {
904 log::warn!(
905 "skipping row {row_index}: failed to parse inputs - {e}",
906 );
907 return None;
908 }
909 },
910 None => None,
911 };
912 let output = get_string("output");
913 let rating = get_string("rating");
914 let feedback = get_string("feedback").unwrap_or_default();
915 let device_id = get_string("device_id");
916 let time = get_string("time");
917 let experiment_name = get_string("experiment_name");
918 let environment = get_string("environment");
919 let zed_version = get_string("zed_version");
920
921 match (inputs, output.clone(), rating.clone(), time.clone()) {
922 (Some(inputs), Some(output), Some(rating), Some(time)) => {
923 Some(build_rated_example(
924 request_id,
925 device_id.unwrap_or_default(),
926 time,
927 inputs,
928 output,
929 rating,
930 feedback,
931 experiment_name,
932 environment,
933 zed_version,
934 ))
935 }
936 _ => {
937 log::warn!(
938 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} time={:?}",
939 inputs_json.is_some(),
940 output.is_some(),
941 rating.is_some(),
942 time.is_some(),
943 );
944 None
945 }
946 }
947 });
948
949 Ok(Box::new(iter))
950}
951
952fn build_rated_example(
953 request_id: Option<String>,
954 device_id: String,
955 time: String,
956 input: ZetaPromptInput,
957 output: String,
958 rating: String,
959 feedback: String,
960 experiment_name: Option<String>,
961 environment: Option<String>,
962 zed_version: Option<String>,
963) -> Example {
964 let parsed_rating = if rating == "Positive" {
965 EditPredictionRating::Positive
966 } else {
967 EditPredictionRating::Negative
968 };
969 let is_positive = parsed_rating == EditPredictionRating::Positive;
970 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
971
972 let mut tags = Vec::with_capacity(3);
973 tags.push(if is_positive {
974 "rated:positive".to_string()
975 } else {
976 "rated:negative".to_string()
977 });
978 if let Some(experiment) = experiment_name {
979 tags.push(format!("experiment:{experiment}"));
980 }
981 if let Some(env) = environment {
982 tags.push(format!("environment:{env}"));
983 }
984
985 let mut example =
986 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
987
988 example.spec.rating = Some(parsed_rating);
989
990 if !feedback.is_empty() {
991 example
992 .spec
993 .human_feedback
994 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
995 }
996
997 if is_positive {
998 example.spec.expected_patches = vec![output];
999 } else {
1000 example.spec.rejected_patch = Some(output);
1001 }
1002
1003 example
1004}
1005
1006fn requested_examples_from_response<'a>(
1007 response: &'a SnowflakeStatementResponse,
1008 column_indices: &'a std::collections::HashMap<String, usize>,
1009) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1010 if let Some(code) = &response.code {
1011 if code != SNOWFLAKE_SUCCESS_CODE {
1012 anyhow::bail!(
1013 "snowflake sql api returned error code={code} message={}",
1014 response.message.as_deref().unwrap_or("<no message>")
1015 );
1016 }
1017 }
1018
1019 let iter = response
1020 .data
1021 .iter()
1022 .enumerate()
1023 .filter_map(move |(row_index, data_row)| {
1024 let get_string = |name: &str| -> Option<String> {
1025 let index = column_indices.get(name).copied()?;
1026 match data_row.get(index)? {
1027 JsonValue::String(s) => Some(s.clone()),
1028 JsonValue::Null => None,
1029 other => Some(other.to_string()),
1030 }
1031 };
1032
1033 let get_json = |name: &str| -> Option<JsonValue> {
1034 let index = column_indices.get(name).copied()?;
1035 let value = data_row.get(index)?;
1036 if value.is_null() {
1037 return None;
1038 }
1039 match value {
1040 JsonValue::String(s) => serde_json::from_str(s).ok(),
1041 other => Some(other.clone()),
1042 }
1043 };
1044
1045 let request_id_str = get_string("request_id");
1046 let device_id = get_string("device_id");
1047 let time = get_string("time");
1048 let input_json = get_json("input");
1049 let input: Option<ZetaPromptInput> =
1050 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1051 let zed_version = get_string("zed_version");
1052
1053 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1054 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1055 Some(build_example_from_snowflake(
1056 request_id,
1057 device_id,
1058 time,
1059 input,
1060 vec!["requested".to_string()],
1061 None,
1062 zed_version,
1063 ))
1064 }
1065 _ => {
1066 log::warn!(
1067 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1068 request_id_str.is_some(),
1069 device_id.is_some(),
1070 time.is_some(),
1071 input_json.is_some(),
1072 );
1073 None
1074 }
1075 }
1076 });
1077
1078 Ok(Box::new(iter))
1079}
1080
1081fn settled_examples_from_response<'a>(
1082 response: &'a SnowflakeStatementResponse,
1083 column_indices: &'a std::collections::HashMap<String, usize>,
1084) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1085 if let Some(code) = &response.code {
1086 if code != SNOWFLAKE_SUCCESS_CODE {
1087 anyhow::bail!(
1088 "snowflake sql api returned error code={code} message={}",
1089 response.message.as_deref().unwrap_or("<no message>")
1090 );
1091 }
1092 }
1093
1094 let iter = response
1095 .data
1096 .iter()
1097 .enumerate()
1098 .filter_map(move |(row_index, data_row)| {
1099 let get_value = |name: &str| -> Option<JsonValue> {
1100 let index = column_indices.get(name).copied()?;
1101 let value = data_row.get(index)?;
1102 if value.is_null() {
1103 None
1104 } else {
1105 Some(value.clone())
1106 }
1107 };
1108
1109 let get_string = |name: &str| -> Option<String> {
1110 match get_value(name)? {
1111 JsonValue::String(s) => Some(s),
1112 other => Some(other.to_string()),
1113 }
1114 };
1115
1116 let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1117 let value = raw?;
1118 match value {
1119 JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1120 other => Some(other.clone()),
1121 }
1122 };
1123
1124 let request_id_str = get_string("request_id");
1125 let device_id = get_string("device_id");
1126 let time = get_string("time");
1127 let input_raw = get_value("input");
1128 let input_json = parse_json_value(input_raw.as_ref());
1129 let input: Option<ZetaPromptInput> = input_json
1130 .as_ref()
1131 .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1132 let requested_output = get_string("requested_output");
1133 let settled_editable_region = get_string("settled_editable_region");
1134 let requested_format =
1135 get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1136 let zed_version = get_string("zed_version");
1137
1138 match (
1139 request_id_str.clone(),
1140 device_id.clone(),
1141 time.clone(),
1142 input.clone(),
1143 requested_output.clone(),
1144 settled_editable_region.clone(),
1145 requested_format,
1146 ) {
1147 (
1148 Some(request_id),
1149 Some(device_id),
1150 Some(time),
1151 Some(input),
1152 Some(requested_output),
1153 Some(settled_editable_region),
1154 Some(requested_format),
1155 ) => Some(build_settled_example(
1156 request_id,
1157 device_id,
1158 time,
1159 input,
1160 requested_output,
1161 settled_editable_region,
1162 requested_format,
1163 zed_version,
1164 )),
1165 _ => {
1166 let mut missing_fields = Vec::new();
1167
1168 if request_id_str.is_none() {
1169 missing_fields.push("request_id");
1170 }
1171 if device_id.is_none() {
1172 missing_fields.push("device_id");
1173 }
1174 if time.is_none() {
1175 missing_fields.push("time");
1176 }
1177 if input_raw.is_none() || input_json.is_none() || input.is_none() {
1178 missing_fields.push("input");
1179 }
1180 if requested_output.is_none() {
1181 missing_fields.push("requested_output");
1182 }
1183 if settled_editable_region.is_none() {
1184 missing_fields.push("settled_editable_region");
1185 }
1186 if requested_format.is_none() {
1187 missing_fields.push("requested_format");
1188 }
1189
1190 log::warn!(
1191 "skipping settled row {row_index}: [{}]",
1192 missing_fields.join(", "),
1193 );
1194 None
1195 }
1196 }
1197 });
1198
1199 Ok(Box::new(iter))
1200}
1201
1202fn captured_examples_from_response<'a>(
1203 response: &'a SnowflakeStatementResponse,
1204 column_indices: &'a std::collections::HashMap<String, usize>,
1205) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1206 if let Some(code) = &response.code {
1207 if code != SNOWFLAKE_SUCCESS_CODE {
1208 anyhow::bail!(
1209 "snowflake sql api returned error code={code} message={}",
1210 response.message.as_deref().unwrap_or("<no message>")
1211 );
1212 }
1213 }
1214
1215 let iter = response
1216 .data
1217 .iter()
1218 .enumerate()
1219 .filter_map(move |(row_index, data_row)| {
1220 let get_value = |name: &str| -> Option<JsonValue> {
1221 let index = column_indices.get(name).copied()?;
1222 let value = data_row.get(index)?;
1223 if value.is_null() {
1224 None
1225 } else {
1226 Some(value.clone())
1227 }
1228 };
1229
1230 let get_string = |name: &str| -> Option<String> {
1231 match get_value(name)? {
1232 JsonValue::String(s) => Some(s),
1233 other => Some(other.to_string()),
1234 }
1235 };
1236
1237 let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1238 let value = raw?;
1239 match value {
1240 JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1241 other => Some(other.clone()),
1242 }
1243 };
1244
1245 let request_id = get_string("request_id");
1246 let device_id = get_string("device_id");
1247 let time = get_string("time");
1248 let input_raw = get_value("input");
1249 let input_json = parse_json_value(input_raw.as_ref());
1250 let input: Option<ZetaPromptInput> = input_json
1251 .as_ref()
1252 .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1253 let example_raw = get_value("example");
1254 let example_json = parse_json_value(example_raw.as_ref());
1255 let example_spec: Option<ExampleSpec> = example_json.as_ref().and_then(|parsed| {
1256 serde_json::from_value(parsed.clone())
1257 .or_else(|_| {
1258 parsed
1259 .as_str()
1260 .and_then(|markdown| ExampleSpec::from_markdown(markdown).ok())
1261 .ok_or_else(|| {
1262 serde_json::Error::io(std::io::Error::other("not markdown"))
1263 })
1264 })
1265 .ok()
1266 });
1267 let has_example_spec = example_spec.is_some();
1268 let settled_editable_region = get_string("settled_editable_region");
1269 let zed_version = get_string("zed_version");
1270
1271 match (
1272 request_id.clone(),
1273 device_id.clone(),
1274 time.clone(),
1275 input.clone(),
1276 example_spec,
1277 settled_editable_region.clone(),
1278 ) {
1279 (
1280 Some(request_id),
1281 Some(device_id),
1282 Some(time),
1283 Some(input),
1284 Some(example_spec),
1285 Some(settled_editable_region),
1286 ) => Some(build_captured_example(
1287 request_id,
1288 device_id,
1289 time,
1290 input,
1291 example_spec,
1292 settled_editable_region,
1293 zed_version,
1294 )),
1295 _ => {
1296 let mut missing_fields = Vec::new();
1297
1298 if request_id.is_none() {
1299 missing_fields.push("request_id");
1300 }
1301 if device_id.is_none() {
1302 missing_fields.push("device_id");
1303 }
1304 if time.is_none() {
1305 missing_fields.push("time");
1306 }
1307 if input_raw.is_none() || input_json.is_none() || input.is_none() {
1308 missing_fields.push("input");
1309 }
1310 if example_raw.is_none() || !has_example_spec {
1311 missing_fields.push("example");
1312 }
1313 if settled_editable_region.is_none() {
1314 missing_fields.push("settled_editable_region");
1315 }
1316
1317 log::warn!(
1318 "skipping captured row {row_index}: [{}]",
1319 missing_fields.join(", "),
1320 );
1321 None
1322 }
1323 }
1324 });
1325
1326 Ok(Box::new(iter))
1327}
1328
1329fn build_settled_example(
1330 request_id: String,
1331 device_id: String,
1332 time: String,
1333 input: ZetaPromptInput,
1334 requested_output: String,
1335 settled_editable_region: String,
1336 requested_format: ZetaFormat,
1337 zed_version: Option<String>,
1338) -> Example {
1339 let requested_editable_range =
1340 excerpt_range_for_format(requested_format, &input.excerpt_ranges).0;
1341
1342 let base_cursor_excerpt = input.cursor_excerpt.to_string();
1343
1344 let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1345 && requested_editable_range.end <= base_cursor_excerpt.len();
1346 let mut example = build_example_from_snowflake(
1347 request_id.clone(),
1348 device_id,
1349 time,
1350 input,
1351 vec!["settled".to_string()],
1352 None,
1353 zed_version,
1354 );
1355
1356 if !requested_range_is_valid {
1357 log::warn!(
1358 "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1359 request_id,
1360 requested_editable_range,
1361 base_cursor_excerpt.len(),
1362 );
1363 return example;
1364 }
1365
1366 let settled_replacement = settled_editable_region.as_str();
1367 let rejected_patch = build_output_patch(
1368 &example.spec.cursor_path,
1369 &base_cursor_excerpt,
1370 &requested_editable_range,
1371 &requested_output,
1372 );
1373 let expected_patch = build_output_patch(
1374 &example.spec.cursor_path,
1375 &base_cursor_excerpt,
1376 &requested_editable_range,
1377 settled_replacement,
1378 );
1379
1380 example.spec.expected_patches = vec![expected_patch];
1381 example.spec.rejected_patch = Some(rejected_patch);
1382 example
1383}
1384
1385fn build_captured_example(
1386 request_id: String,
1387 device_id: String,
1388 time: String,
1389 input: ZetaPromptInput,
1390 mut example_spec: ExampleSpec,
1391 settled_editable_region: String,
1392 zed_version: Option<String>,
1393) -> Example {
1394 let expected_patch = build_output_patch(
1395 &input.cursor_path,
1396 input.cursor_excerpt.as_ref(),
1397 &input.excerpt_ranges.editable_350,
1398 settled_editable_region.as_str(),
1399 );
1400
1401 example_spec.expected_patches = vec![expected_patch];
1402 example_spec.telemetry = Some(TelemetrySource {
1403 request_id,
1404 device_id,
1405 time,
1406 rejection_reason: String::new(),
1407 was_shown: false,
1408 });
1409
1410 Example {
1411 spec: example_spec,
1412 zed_version,
1413 prompt_inputs: Some(input),
1414 prompt: None,
1415 predictions: Vec::new(),
1416 score: Vec::new(),
1417 qa: Vec::new(),
1418 state: None,
1419 }
1420}
1421
1422fn rejected_examples_from_response<'a>(
1423 response: &'a SnowflakeStatementResponse,
1424 column_indices: &'a std::collections::HashMap<String, usize>,
1425) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1426 if let Some(code) = &response.code {
1427 if code != SNOWFLAKE_SUCCESS_CODE {
1428 anyhow::bail!(
1429 "snowflake sql api returned error code={code} message={}",
1430 response.message.as_deref().unwrap_or("<no message>")
1431 );
1432 }
1433 }
1434
1435 let iter = response
1436 .data
1437 .iter()
1438 .enumerate()
1439 .filter_map(move |(row_index, data_row)| {
1440 let get_string = |name: &str| -> Option<String> {
1441 let index = column_indices.get(name).copied()?;
1442 match data_row.get(index)? {
1443 JsonValue::String(s) => Some(s.clone()),
1444 JsonValue::Null => None,
1445 other => Some(other.to_string()),
1446 }
1447 };
1448
1449 let get_json = |name: &str| -> Option<JsonValue> {
1450 let index = column_indices.get(name).copied()?;
1451 let value = data_row.get(index)?;
1452 if value.is_null() {
1453 return None;
1454 }
1455 match value {
1456 JsonValue::String(s) => serde_json::from_str(s).ok(),
1457 other => Some(other.clone()),
1458 }
1459 };
1460
1461 let get_bool = |name: &str| -> Option<bool> {
1462 let index = column_indices.get(name).copied()?;
1463 match data_row.get(index)? {
1464 JsonValue::Bool(b) => Some(*b),
1465 JsonValue::String(s) => s.parse().ok(),
1466 _ => None,
1467 }
1468 };
1469
1470 let request_id_str = get_string("request_id");
1471 let device_id = get_string("device_id");
1472 let time = get_string("time");
1473 let input_json = get_json("input");
1474 let input: Option<ZetaPromptInput> =
1475 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1476 let output = get_string("output");
1477 let was_shown = get_bool("was_shown");
1478 let reason = get_string("reason");
1479 let zed_version = get_string("zed_version");
1480
1481 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1482 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1483 Some(build_rejected_example(
1484 request_id,
1485 device_id,
1486 time,
1487 input,
1488 output,
1489 was_shown,
1490 reason,
1491 zed_version,
1492 ))
1493 }
1494 _ => {
1495 log::warn!(
1496 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1497 request_id_str.is_some(),
1498 device_id.is_some(),
1499 time.is_some(),
1500 input_json.is_some(),
1501 output.is_some(),
1502 was_shown.is_some(),
1503 reason.is_some()
1504 );
1505 None
1506 }
1507 }
1508 });
1509
1510 Ok(Box::new(iter))
1511}
1512
1513fn build_rejected_example(
1514 request_id: String,
1515 device_id: String,
1516 time: String,
1517 input: ZetaPromptInput,
1518 output: String,
1519 was_shown: bool,
1520 reason: String,
1521 zed_version: Option<String>,
1522) -> Example {
1523 let rejected_patch = build_output_patch(
1524 &input.cursor_path,
1525 input.cursor_excerpt.as_ref(),
1526 &input.excerpt_ranges.editable_350,
1527 &output,
1528 );
1529 let mut example = build_example_from_snowflake(
1530 request_id,
1531 device_id,
1532 time,
1533 input,
1534 vec![format!("rejection:{}", reason.to_lowercase())],
1535 Some(RejectionInfo { reason, was_shown }),
1536 zed_version,
1537 );
1538 example.spec.rejected_patch = Some(rejected_patch);
1539 example
1540}
1541
/// Rejection details recorded in the telemetry of an example that was built
/// from a rejected prediction.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RejectionInfo {
    /// Why the prediction was rejected, as reported by the telemetry row.
    reason: String,
    /// Whether the prediction was shown before it was rejected.
    was_shown: bool,
}
1546
1547fn build_example_from_snowflake(
1548 request_id: String,
1549 device_id: String,
1550 time: String,
1551 input: ZetaPromptInput,
1552 tags: Vec<String>,
1553 rejection: Option<RejectionInfo>,
1554 zed_version: Option<String>,
1555) -> Example {
1556 let cursor_excerpt = input.cursor_excerpt.as_ref();
1557 let cursor_offset = input.cursor_offset_in_excerpt;
1558
1559 let mut edit_history = String::new();
1560 for event in &input.events {
1561 zeta_prompt::write_event(&mut edit_history, event);
1562 edit_history.push('\n');
1563 }
1564
1565 let (rejection_reason, was_shown) = match &rejection {
1566 Some(r) => (r.reason.clone(), r.was_shown),
1567 None => (String::new(), false),
1568 };
1569
1570 let spec = ExampleSpec {
1571 name: request_id.clone(),
1572 repository_url: String::new(),
1573 revision: String::new(),
1574 tags,
1575 reasoning: None,
1576 uncommitted_diff: String::new(),
1577 cursor_path: input.cursor_path.clone(),
1578 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1579 edit_history,
1580 expected_patches: Vec::new(),
1581 rejected_patch: None,
1582 telemetry: Some(TelemetrySource {
1583 request_id,
1584 device_id,
1585 time,
1586 rejection_reason,
1587 was_shown,
1588 }),
1589 human_feedback: Vec::new(),
1590 rating: None,
1591 };
1592
1593 Example {
1594 spec,
1595 zed_version,
1596 prompt_inputs: Some(input),
1597 prompt: None,
1598 predictions: Vec::new(),
1599 score: Vec::new(),
1600 qa: Vec::new(),
1601 state: None,
1602 }
1603}
1604
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. It is clamped to the end of the excerpt,
/// and if it falls inside a multi-byte UTF-8 character it is moved back to the
/// start of that character, so the function never panics on a bad offset.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1610
1611fn build_output_patch(
1612 cursor_path: &std::path::Path,
1613 cursor_excerpt: &str,
1614 editable_range: &std::ops::Range<usize>,
1615 model_output: &str,
1616) -> String {
1617 let old_text = &cursor_excerpt[editable_range.clone()];
1618
1619 let editable_start_row = cursor_excerpt[..editable_range.start]
1620 .chars()
1621 .filter(|&c| c == '\n')
1622 .count() as u32;
1623
1624 let diff_body = language::unified_diff_with_offsets(
1625 old_text,
1626 model_output,
1627 editable_start_row,
1628 editable_start_row,
1629 );
1630
1631 let mut patch = String::new();
1632 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1633 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1634 patch.push_str(&diff_body);
1635 patch
1636}
1637
1638pub(crate) fn get_column_indices(
1639 meta: &Option<SnowflakeResultSetMetaData>,
1640 names: &[&str],
1641) -> std::collections::HashMap<String, usize> {
1642 let mut indices = std::collections::HashMap::new();
1643 if let Some(meta) = meta {
1644 for (index, col) in meta.row_type.iter().enumerate() {
1645 for &name in names {
1646 if col.name.eq_ignore_ascii_case(name) {
1647 indices.insert(name.to_string(), index);
1648 }
1649 }
1650 }
1651 }
1652 indices
1653}