1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::collections::HashMap;
9use std::fmt::Write as _;
10use std::io::Read;
11use std::sync::Arc;
12use std::time::Duration;
13use telemetry_events::EditPredictionRating;
14
15use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
16
17use crate::PredictionProvider;
18use crate::example::{Example, ExamplePrompt};
19use crate::progress::{InfoStyle, Progress, Step};
20use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
21
/// Snowflake SQL API response `code` that indicates success; response parsers in this module
/// treat any other code present in a result as an error.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Snowflake SQL API response `code` reported while an asynchronously executing statement is
/// still running; `run_sql_with_polling` keeps polling while it sees this code.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// Snowflake response code treated as a statement timeout.
// NOTE(review): presumably consumed by `is_timeout_response` (defined outside this view).
const SNOWFLAKE_TIMEOUT_CODE: &str = "000630";
25
/// Lower bound on the Zed version of pulled examples.
///
/// Only the minor and patch components take part in the comparison; the major component is
/// not checked. For example, `MinCaptureVersion { minor: 224, patch: 1 }` keeps only examples
/// whose `zed_version >= 0.224.1`.
#[derive(Debug, Clone, Copy)]
pub struct MinCaptureVersion {
    /// Smallest accepted minor version component.
    pub minor: u32,
    /// Smallest accepted patch component, applied only when the minor component equals `minor`.
    pub patch: u32,
}
34
/// Delay between successive polls of an asynchronously running statement.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_millis(2_000);
/// Number of times a transient partition-fetch failure is retried.
const PARTITION_FETCH_MAX_RETRIES: usize = 3;
/// Back-off waited before each retry; entry `i` is used after the `i`-th (0-based) failed attempt.
const PARTITION_FETCH_RETRY_DELAYS: [Duration; PARTITION_FETCH_MAX_RETRIES] = [
    Duration::from_millis(500),
    Duration::from_millis(1_000),
    Duration::from_millis(2_000),
];
42
/// Extracts the timestamp from an input token shaped like `captured-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `captured-after:` prefix.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input.strip_prefix(PREFIX)
}
47
/// Extracts the timestamp from an input token shaped like `rejected-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `rejected-after:` prefix.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    input.strip_prefix(PREFIX)
}
52
/// Extracts the timestamp from an input token shaped like `requested-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `requested-after:` prefix.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    input.strip_prefix(PREFIX)
}
57
/// Extracts the timestamp from an input token shaped like `settled-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with the `settled-after:` prefix.
pub fn parse_settled_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "settled-after:";
    input.strip_prefix(PREFIX)
}
62
63/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
64/// or `rated-negative-after:{timestamp}`.
65/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
66pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
67 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
68 Some((timestamp, Some(EditPredictionRating::Positive)))
69 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
70 Some((timestamp, Some(EditPredictionRating::Negative)))
71 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
72 Some((timestamp, None))
73 } else {
74 None
75 }
76}
77
/// Body of a Snowflake SQL API statement response, as parsed by `run_sql` (statement
/// submission) and `fetch_partition` (partition fetch / status poll).
///
/// Every field falls back to its default when absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each row is an array of column values.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Column and partition metadata for the result set (`resultSetMetaData`).
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Snowflake status code, compared against e.g. `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Message accompanying `code`; included in error reports.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll the statement and fetch additional partitions (`statementHandle`).
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
92
/// `resultSetMetaData` section of a statement response.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeResultSetMetaData {
    /// Column descriptions.
    // NOTE(review): presumably in the same order as the values of each data row; the mapping
    // is built by `get_column_indices` (outside this view). The explicit rename matches what
    // `rename_all = "camelCase"` would produce anyway.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported by Snowflake; used only for progress display.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
103
/// Entry of `partitionInfo`. Its contents are not needed, only how many entries exist, so
/// every field of the JSON object is ignored.
// NOTE(review): `rename_all` has no effect on a struct without fields.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
107
/// Entry of `rowType`: metadata for a single result column.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name as reported by Snowflake; empty when absent.
    #[serde(default)]
    name: String,
}
113
114async fn run_sql_with_polling(
115 http_client: Arc<dyn HttpClient>,
116 base_url: &str,
117 token: &str,
118 request: &serde_json::Value,
119 step_progress: &crate::progress::StepProgress,
120 background_executor: BackgroundExecutor,
121) -> Result<SnowflakeStatementResponse> {
122 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
123
124 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
125 let statement_handle = response
126 .statement_handle
127 .as_ref()
128 .context("async query response missing statementHandle")?
129 .clone();
130
131 for attempt in 0.. {
132 step_progress.set_substatus(format!("polling ({attempt})"));
133
134 background_executor.timer(POLL_INTERVAL).await;
135
136 response = fetch_partition_with_retries(
137 http_client.clone(),
138 base_url,
139 token,
140 &statement_handle,
141 0,
142 background_executor.clone(),
143 )
144 .await?;
145
146 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
147 break;
148 }
149 }
150 }
151
152 Ok(response)
153}
154
/// Connection settings for the Snowflake SQL API, read from the environment by
/// `fetch_examples_with_query`.
struct SnowflakeConfig {
    /// Account URL, e.g. `https://<account>.snowflakecomputing.com`.
    base_url: String,
    /// Programmatic access token sent as a bearer token.
    token: String,
    /// Optional role to run statements under.
    role: Option<String>,
}
160
/// Pagination cursor for `fetch_examples_with_query`, carried across retries and turned into
/// SQL bindings by each caller.
#[derive(Clone)]
struct QueryRetryState {
    /// Timestamp bound into the query; callers only request rows after it.
    resume_after: String,
    /// Rows still allowed to be fetched, or `None` for no limit.
    remaining_limit: Option<usize>,
    /// Number of leading rows to skip; reset to 0 when resuming after a timeout.
    offset: usize,
}
167
/// Runs a resumable example query against Snowflake and parses every returned row.
///
/// Credentials come from `EP_SNOWFLAKE_API_KEY`, `EP_SNOWFLAKE_BASE_URL` and the optional
/// `EP_SNOWFLAKE_ROLE`. `statement` is executed with bindings produced by `make_bindings`
/// from the current [`QueryRetryState`]. The first result partition comes with the statement
/// response; any further partitions are fetched by statement handle. Each partition is turned
/// into examples by `parse_response`, using column indices resolved for `required_columns`
/// plus `continuation_time`. That column is always requested so the query can be resumed.
///
/// Retry behaviour (timeouts are recognised by `is_snowflake_timeout_error`):
/// - If a later partition times out after rows were fetched in this attempt, the query is
///   re-issued with `resume_after` set to the last `continuation_time` seen, `offset` reset
///   to 0, and the remaining limit reduced by the rows already fetched.
/// - If the statement itself times out and examples have already been collected, the same
///   request is sent again unchanged.
///
/// Returns all collected examples once an attempt yields no rows, the remaining limit reaches
/// zero, every partition was fetched without timing out, or no continuation time is available.
///
/// NOTE(review): neither timeout path caps `retry_count` or backs off, so a statement that
/// keeps timing out is retried indefinitely. Callers resume with `requested_at > resume_after`,
/// so rows sharing the exact last timestamp may be skipped. Confirm both are acceptable.
async fn fetch_examples_with_query<MakeBindings>(
    http_client: Arc<dyn HttpClient>,
    step_progress: &crate::progress::StepProgress,
    background_executor: BackgroundExecutor,
    statement: &str,
    initial_retry_state: QueryRetryState,
    make_bindings: MakeBindings,
    required_columns: &[&str],
    parse_response: for<'a> fn(
        &'a SnowflakeStatementResponse,
        &'a HashMap<String, usize>,
    ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
) -> Result<Vec<Example>>
where
    MakeBindings: Fn(&QueryRetryState) -> JsonValue,
{
    let snowflake = SnowflakeConfig {
        token: std::env::var("EP_SNOWFLAKE_API_KEY")
            .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
        base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context(
            "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
        )?,
        role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
    };

    // `continuation_time` is needed to resume after a partition timeout, even when the caller
    // does not list it.
    let mut requested_columns = required_columns.to_vec();
    if !requested_columns.contains(&"continuation_time") {
        requested_columns.push("continuation_time");
    }

    let mut parsed_examples = Vec::new();
    let mut retry_state = initial_retry_state;
    let mut retry_count = 0usize;

    loop {
        let bindings = make_bindings(&retry_state);
        let request = json!({
            "statement": statement,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": snowflake.role.as_deref(),
            "bindings": bindings
        });

        let response = match run_sql_with_polling(
            http_client.clone(),
            &snowflake.base_url,
            &snowflake.token,
            &request,
            step_progress,
            background_executor.clone(),
        )
        .await
        {
            Ok(response) => response,
            Err(error) => {
                // A statement-level timeout is only retried once something has been collected;
                // the request is re-sent with an unchanged `retry_state`.
                if is_snowflake_timeout_error(&error) && !parsed_examples.is_empty() {
                    retry_count += 1;
                    step_progress.set_substatus(format!(
                        "retrying from {} ({retry_count})",
                        retry_state.resume_after
                    ));
                    continue;
                }

                return Err(error);
            }
        };

        // `numRows` is only used for the progress display; fall back to the first partition's size.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|meta| meta.num_rows)
            .unwrap_or(response.data.len() as i64);
        let partition_count = response
            .result_set_meta_data
            .as_ref()
            .map(|meta| meta.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        let column_indices = get_column_indices(&response.result_set_meta_data, &requested_columns);
        let mut rows_fetched_this_attempt = 0usize;
        let mut timed_out_fetching_partition = false;

        // Partition 0 is delivered inline with the statement response.
        parsed_examples.extend(parse_response(&response, &column_indices)?);
        rows_fetched_this_attempt += response.data.len();
        let mut last_continuation_time_this_attempt =
            last_continuation_timestamp_from_response(&response, &column_indices);

        if partition_count > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..partition_count {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    partition_count
                ));

                let partition_response = match fetch_partition_with_retries(
                    http_client.clone(),
                    &snowflake.base_url,
                    &snowflake.token,
                    statement_handle,
                    partition,
                    background_executor.clone(),
                )
                .await
                {
                    Ok(response) => response,
                    Err(error) => {
                        // Keep what was fetched so far and resume from its last timestamp below.
                        if is_snowflake_timeout_error(&error) && rows_fetched_this_attempt > 0 {
                            timed_out_fetching_partition = true;
                            break;
                        }

                        return Err(error);
                    }
                };

                parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
                rows_fetched_this_attempt += partition_response.data.len();

                if let Some(partition_continuation_time) =
                    last_continuation_timestamp_from_response(&partition_response, &column_indices)
                {
                    last_continuation_time_this_attempt = Some(partition_continuation_time);
                }
            }
        }

        if rows_fetched_this_attempt == 0 {
            step_progress.set_substatus("done");
            return Ok(parsed_examples);
        }

        // The limit counts fetched rows (not successfully parsed examples).
        if let Some(remaining_limit_value) = &mut retry_state.remaining_limit {
            *remaining_limit_value =
                remaining_limit_value.saturating_sub(rows_fetched_this_attempt);
            if *remaining_limit_value == 0 {
                step_progress.set_substatus("done");
                return Ok(parsed_examples);
            }
        }

        if !timed_out_fetching_partition {
            step_progress.set_substatus("done");
            return Ok(parsed_examples);
        }

        let Some(last_continuation_time_this_attempt) = last_continuation_time_this_attempt else {
            step_progress.set_substatus("done");
            return Ok(parsed_examples);
        };

        // Resume after the last row we have; the offset was already applied by the first attempt.
        retry_state.resume_after = last_continuation_time_this_attempt;
        retry_state.offset = 0;
        retry_count += 1;
        step_progress.set_substatus(format!(
            "retrying from {} ({retry_count})",
            retry_state.resume_after
        ));
    }
}
340
341pub(crate) async fn fetch_partition(
342 http_client: Arc<dyn HttpClient>,
343 base_url: &str,
344 token: &str,
345 statement_handle: &str,
346 partition: usize,
347) -> Result<SnowflakeStatementResponse> {
348 let url = format!(
349 "{}/api/v2/statements/{}?partition={}",
350 base_url.trim_end_matches('/'),
351 statement_handle,
352 partition
353 );
354
355 let http_request = Request::builder()
356 .method(Method::GET)
357 .uri(url.as_str())
358 .header("Authorization", format!("Bearer {token}"))
359 .header(
360 "X-Snowflake-Authorization-Token-Type",
361 "PROGRAMMATIC_ACCESS_TOKEN",
362 )
363 .header("Accept", "application/json")
364 .header("Accept-Encoding", "gzip")
365 .header("User-Agent", "edit_prediction_cli")
366 .body(AsyncBody::empty())?;
367
368 let response = http_client
369 .send(http_request)
370 .await
371 .context("failed to send partition request to Snowflake SQL API")?;
372
373 let status = response.status();
374 let content_encoding = response
375 .headers()
376 .get("content-encoding")
377 .and_then(|v| v.to_str().ok())
378 .map(|s| s.to_lowercase());
379
380 let body_bytes = {
381 use futures::AsyncReadExt as _;
382
383 let mut body = response.into_body();
384 let mut bytes = Vec::new();
385 body.read_to_end(&mut bytes)
386 .await
387 .context("failed to read Snowflake SQL API partition response body")?;
388 bytes
389 };
390
391 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
392 let mut decoder = GzDecoder::new(&body_bytes[..]);
393 let mut decompressed = Vec::new();
394 decoder
395 .read_to_end(&mut decompressed)
396 .context("failed to decompress gzip response")?;
397 decompressed
398 } else {
399 body_bytes
400 };
401
402 if !status.is_success() && status.as_u16() != 202 {
403 let body_text = String::from_utf8_lossy(&body_bytes);
404 anyhow::bail!(
405 "snowflake sql api partition request http {}: {}",
406 status.as_u16(),
407 body_text
408 );
409 }
410
411 if body_bytes.is_empty() {
412 anyhow::bail!(
413 "snowflake sql api partition {} returned empty response body (http {})",
414 partition,
415 status.as_u16()
416 );
417 }
418
419 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
420 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
421 format!(
422 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
423 partition,
424 status.as_u16(),
425 body_preview
426 )
427 })
428}
429
430async fn fetch_partition_with_retries(
431 http_client: Arc<dyn HttpClient>,
432 base_url: &str,
433 token: &str,
434 statement_handle: &str,
435 partition: usize,
436 background_executor: BackgroundExecutor,
437) -> Result<SnowflakeStatementResponse> {
438 let mut last_error = None;
439
440 for retry_attempt in 0..=PARTITION_FETCH_MAX_RETRIES {
441 match fetch_partition(
442 http_client.clone(),
443 base_url,
444 token,
445 statement_handle,
446 partition,
447 )
448 .await
449 {
450 Ok(response) => return Ok(response),
451 Err(error) => {
452 if retry_attempt == PARTITION_FETCH_MAX_RETRIES
453 || !is_transient_partition_fetch_error(&error)
454 {
455 return Err(error);
456 }
457
458 last_error = Some(error);
459 background_executor
460 .timer(PARTITION_FETCH_RETRY_DELAYS[retry_attempt])
461 .await;
462 }
463 }
464 }
465
466 match last_error {
467 Some(error) => Err(error),
468 None => anyhow::bail!("partition fetch retry loop exited without a result"),
469 }
470}
471
472fn is_transient_partition_fetch_error(error: &anyhow::Error) -> bool {
473 error.chain().any(|cause| {
474 let message = cause.to_string();
475 message.contains("failed to read Snowflake SQL API partition response body")
476 || message.contains("unexpected EOF")
477 || message.contains("peer closed connection without sending TLS close_notify")
478 })
479}
480
481pub(crate) async fn run_sql(
482 http_client: Arc<dyn HttpClient>,
483 base_url: &str,
484 token: &str,
485 request: &serde_json::Value,
486) -> Result<SnowflakeStatementResponse> {
487 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
488
489 let request_body =
490 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
491
492 let http_request = Request::builder()
493 .method(Method::POST)
494 .uri(url.as_str())
495 .header("Authorization", format!("Bearer {token}"))
496 .header(
497 "X-Snowflake-Authorization-Token-Type",
498 "PROGRAMMATIC_ACCESS_TOKEN",
499 )
500 .header("Content-Type", "application/json")
501 .header("Accept", "application/json")
502 .header("User-Agent", "edit_prediction_cli")
503 .body(AsyncBody::from(request_body.clone()))?;
504
505 let response = http_client
506 .send(http_request)
507 .await
508 .context("failed to send request to Snowflake SQL API")?;
509
510 let status = response.status();
511 let body_bytes = {
512 use futures::AsyncReadExt as _;
513
514 let mut body = response.into_body();
515 let mut bytes = Vec::new();
516 body.read_to_end(&mut bytes)
517 .await
518 .context("failed to read Snowflake SQL API response body")?;
519 bytes
520 };
521
522 let snowflake_response = serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
523 .context("failed to parse Snowflake SQL API response JSON")?;
524
525 if !status.is_success() && status.as_u16() != 202 && !is_timeout_response(&snowflake_response) {
526 let body_text = String::from_utf8_lossy(&body_bytes);
527 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
528 }
529
530 if is_timeout_response(&snowflake_response) {
531 anyhow::bail!(
532 "snowflake sql api timed out code={} message={}",
533 snowflake_response.code.as_deref().unwrap_or("<no code>"),
534 snowflake_response
535 .message
536 .as_deref()
537 .unwrap_or("<no message>")
538 );
539 }
540
541 Ok(snowflake_response)
542}
543
544pub async fn fetch_rejected_examples_after(
545 http_client: Arc<dyn HttpClient>,
546 after_timestamps: &[String],
547 max_rows_per_timestamp: Option<usize>,
548 offset: usize,
549 background_executor: BackgroundExecutor,
550 min_capture_version: Option<MinCaptureVersion>,
551) -> Result<Vec<Example>> {
552 if after_timestamps.is_empty() {
553 return Ok(Vec::new());
554 }
555
556 let progress = Progress::global();
557
558 let mut all_examples = Vec::new();
559
560 for after_date in after_timestamps.iter() {
561 let step_progress_name = format!("rejected>{after_date}");
562 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
563 step_progress.set_substatus("querying");
564
565 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
566 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
567 let min_minor_str_ref = min_minor_str.as_deref();
568 let min_patch_str_ref = min_patch_str.as_deref();
569
570 let statement = indoc! {r#"
571 SELECT
572 ep_request_id AS request_id,
573 device_id AS device_id,
574 requested_at::string AS continuation_time,
575 requested_at::string AS time,
576 input_payload AS input,
577 prompt AS prompt,
578 requested_output AS output,
579 is_ep_shown_before_rejected AS was_shown,
580 ep_rejected_reason AS reason,
581 zed_version AS zed_version
582 FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
583 WHERE ep_outcome LIKE 'Rejected%'
584 AND is_ep_shown_before_rejected = true
585 AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
586 AND (? IS NULL OR (
587 TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
588 OR (
589 TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
590 AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
591 )
592 ))
593 ORDER BY requested_at ASC
594 LIMIT ?
595 OFFSET ?
596 "#};
597
598 let examples = fetch_examples_with_query(
599 http_client.clone(),
600 &step_progress,
601 background_executor.clone(),
602 statement,
603 QueryRetryState {
604 resume_after: after_date.clone(),
605 remaining_limit: max_rows_per_timestamp,
606 offset,
607 },
608 |retry_state| {
609 json!({
610 "1": { "type": "TEXT", "value": retry_state.resume_after },
611 "2": { "type": "FIXED", "value": min_minor_str_ref },
612 "3": { "type": "FIXED", "value": min_minor_str_ref },
613 "4": { "type": "FIXED", "value": min_minor_str_ref },
614 "5": { "type": "FIXED", "value": min_patch_str_ref },
615 "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
616 "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
617 })
618 },
619 &[
620 "request_id",
621 "device_id",
622 "time",
623 "input",
624 "prompt",
625 "output",
626 "was_shown",
627 "reason",
628 "zed_version",
629 ],
630 rejected_examples_from_response,
631 )
632 .await?;
633
634 all_examples.extend(examples);
635 }
636
637 Ok(all_examples)
638}
639
/// Renders an optional row limit as the text bound to a SQL `LIMIT ?` placeholder.
///
/// `Some(n)` becomes the decimal digits of `n`; `None` becomes the literal `"NULL"`.
// NOTE(review): presumably Snowflake treats the bound `"NULL"` as `LIMIT NULL` (no limit); confirm.
fn format_limit(limit: Option<usize>) -> String {
    // `map_or_else` only allocates the fallback string when there is no limit.
    limit.map_or_else(|| "NULL".to_string(), |rows| rows.to_string())
}
643
644pub async fn fetch_requested_examples_after(
645 http_client: Arc<dyn HttpClient>,
646 after_timestamps: &[String],
647 max_rows_per_timestamp: Option<usize>,
648 offset: usize,
649 background_executor: BackgroundExecutor,
650 min_capture_version: Option<MinCaptureVersion>,
651) -> Result<Vec<Example>> {
652 if after_timestamps.is_empty() {
653 return Ok(Vec::new());
654 }
655
656 let progress = Progress::global();
657
658 let mut all_examples = Vec::new();
659
660 for after_date in after_timestamps.iter() {
661 let step_progress_name = format!("requested>{after_date}");
662 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
663 step_progress.set_substatus("querying");
664
665 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
666 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
667 let min_minor_str_ref = min_minor_str.as_deref();
668 let min_patch_str_ref = min_patch_str.as_deref();
669
670 let statement = indoc! {r#"
671 SELECT
672 ep_request_id AS request_id,
673 device_id AS device_id,
674 requested_at::string AS continuation_time,
675 requested_at::string AS time,
676 input_payload AS input,
677 zed_version AS zed_version
678 FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
679 WHERE requested_at > TRY_TO_TIMESTAMP_NTZ(?)
680 AND (? IS NULL OR (
681 TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
682 OR (
683 TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
684 AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
685 )
686 ))
687 ORDER BY requested_at ASC
688 LIMIT ?
689 OFFSET ?
690 "#};
691
692 let examples = fetch_examples_with_query(
693 http_client.clone(),
694 &step_progress,
695 background_executor.clone(),
696 statement,
697 QueryRetryState {
698 resume_after: after_date.clone(),
699 remaining_limit: max_rows_per_timestamp,
700 offset,
701 },
702 |retry_state| {
703 json!({
704 "1": { "type": "TEXT", "value": retry_state.resume_after },
705 "2": { "type": "FIXED", "value": min_minor_str_ref },
706 "3": { "type": "FIXED", "value": min_minor_str_ref },
707 "4": { "type": "FIXED", "value": min_minor_str_ref },
708 "5": { "type": "FIXED", "value": min_patch_str_ref },
709 "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
710 "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
711 })
712 },
713 &["request_id", "device_id", "time", "input", "zed_version"],
714 requested_examples_from_response,
715 )
716 .await?;
717
718 all_examples.extend(examples);
719 }
720
721 Ok(all_examples)
722}
723
724pub async fn fetch_captured_examples_after(
725 http_client: Arc<dyn HttpClient>,
726 after_timestamps: &[String],
727 max_rows_per_timestamp: Option<usize>,
728 offset: usize,
729 background_executor: BackgroundExecutor,
730 min_capture_version: Option<MinCaptureVersion>,
731) -> Result<Vec<Example>> {
732 if after_timestamps.is_empty() {
733 return Ok(Vec::new());
734 }
735
736 let progress = Progress::global();
737
738 let mut all_examples = Vec::new();
739
740 for after_date in after_timestamps.iter() {
741 let step_progress_name = format!("captured>{after_date}");
742 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
743 step_progress.set_substatus("querying");
744
745 let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
746 let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
747 let min_minor_str_ref = min_minor_str.as_deref();
748 let min_patch_str_ref = min_patch_str.as_deref();
749
750 let statement = indoc! {r#"
751 SELECT
752 ep_request_id AS request_id,
753 device_id AS device_id,
754 requested_at::string AS continuation_time,
755 requested_at::string AS time,
756 input_payload AS input,
757 settled_editable_region AS settled_editable_region,
758 example_payload AS example,
759 zed_version AS zed_version
760 FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
761 WHERE settled_editable_region IS NOT NULL
762 AND example_payload IS NOT NULL
763 AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
764 AND (? IS NULL OR (
765 TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
766 OR (
767 TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
768 AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
769 )
770 ))
771 ORDER BY requested_at ASC
772 LIMIT ?
773 OFFSET ?
774 "#};
775
776 let examples = fetch_examples_with_query(
777 http_client.clone(),
778 &step_progress,
779 background_executor.clone(),
780 statement,
781 QueryRetryState {
782 resume_after: after_date.clone(),
783 remaining_limit: max_rows_per_timestamp,
784 offset,
785 },
786 |retry_state| {
787 json!({
788 "1": { "type": "TEXT", "value": retry_state.resume_after },
789 "2": { "type": "FIXED", "value": min_minor_str_ref },
790 "3": { "type": "FIXED", "value": min_minor_str_ref },
791 "4": { "type": "FIXED", "value": min_minor_str_ref },
792 "5": { "type": "FIXED", "value": min_patch_str_ref },
793 "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
794 "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
795 })
796 },
797 &[
798 "request_id",
799 "device_id",
800 "time",
801 "input",
802 "settled_editable_region",
803 "example",
804 "zed_version",
805 ],
806 captured_examples_from_response,
807 )
808 .await?;
809
810 all_examples.extend(examples);
811 }
812
813 Ok(all_examples)
814}
815
816pub async fn fetch_settled_examples_after(
817 http_client: Arc<dyn HttpClient>,
818 after_timestamps: &[String],
819 max_rows_per_timestamp: Option<usize>,
820 offset: usize,
821 background_executor: BackgroundExecutor,
822 min_capture_version: Option<MinCaptureVersion>,
823) -> Result<Vec<Example>> {
824 if after_timestamps.is_empty() {
825 return Ok(Vec::new());
826 }
827
828 let progress = Progress::global();
829
830 let mut all_examples = Vec::new();
831
832 for after_date in after_timestamps.iter() {
833 let step_progress_name = format!("settled>{after_date}");
834 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
835 step_progress.set_substatus("querying");
836
837 let _ = min_capture_version;
838
839 let statement = indoc! {r#"
840 SELECT
841 ep_request_id AS request_id,
842 device_id AS device_id,
843 requested_at::string AS continuation_time,
844 requested_at::string AS time,
845 input_payload AS input,
846 requested_output AS requested_output,
847 settled_editable_region AS settled_editable_region,
848 requested_format AS requested_format,
849 zed_version AS zed_version
850 FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
851 WHERE settled_editable_region IS NOT NULL
852 AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
853 ORDER BY requested_at ASC
854 LIMIT ?
855 OFFSET ?
856 "#};
857
858 let examples = fetch_examples_with_query(
859 http_client.clone(),
860 &step_progress,
861 background_executor.clone(),
862 statement,
863 QueryRetryState {
864 resume_after: after_date.clone(),
865 remaining_limit: max_rows_per_timestamp,
866 offset,
867 },
868 |retry_state| {
869 json!({
870 "1": { "type": "TEXT", "value": retry_state.resume_after },
871 "2": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
872 "3": { "type": "FIXED", "value": retry_state.offset.to_string() }
873 })
874 },
875 &[
876 "request_id",
877 "device_id",
878 "time",
879 "input",
880 "requested_output",
881 "settled_editable_region",
882 "requested_format",
883 "zed_version",
884 ],
885 settled_examples_from_response,
886 )
887 .await?;
888
889 all_examples.extend(examples);
890 }
891
892 Ok(all_examples)
893}
894
895pub async fn fetch_rated_examples_after(
896 http_client: Arc<dyn HttpClient>,
897 inputs: &[(String, Option<EditPredictionRating>)],
898 max_rows_per_timestamp: Option<usize>,
899 offset: usize,
900 background_executor: BackgroundExecutor,
901 _min_capture_version: Option<MinCaptureVersion>,
902) -> Result<Vec<Example>> {
903 if inputs.is_empty() {
904 return Ok(Vec::new());
905 }
906
907 let progress = Progress::global();
908
909 let mut all_examples = Vec::new();
910
911 for (after_date, rating_filter) in inputs.iter() {
912 let filter_label = match rating_filter {
913 None => "",
914 Some(EditPredictionRating::Positive) => ":positive",
915 Some(EditPredictionRating::Negative) => ":negative",
916 };
917 let step_progress_name = format!("rated{filter_label}>{after_date}");
918 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
919 step_progress.set_substatus("querying");
920
921 let rating_value = rating_filter.as_ref().map(|rating| match rating {
922 EditPredictionRating::Positive => "Positive",
923 EditPredictionRating::Negative => "Negative",
924 });
925
926 let statement = indoc! {r#"
927 SELECT
928 ep_request_id AS request_id,
929 rated_inputs AS inputs,
930 rated_output AS output,
931 rating AS rating,
932 feedback AS feedback,
933 device_id AS device_id,
934 requested_at::string AS continuation_time,
935 requested_at::string AS time,
936 NULL AS experiment_name,
937 NULL AS environment,
938 zed_version AS zed_version
939 FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
940 WHERE rating IS NOT NULL
941 AND (? IS NULL OR rating = ?)
942 AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
943 AND rated_inputs IS NOT NULL
944 AND rated_inputs:cursor_excerpt IS NOT NULL
945 AND rated_output IS NOT NULL
946 ORDER BY requested_at ASC
947 LIMIT ?
948 OFFSET ?
949 "#};
950
951 let examples = fetch_examples_with_query(
952 http_client.clone(),
953 &step_progress,
954 background_executor.clone(),
955 statement,
956 QueryRetryState {
957 resume_after: after_date.clone(),
958 remaining_limit: max_rows_per_timestamp,
959 offset,
960 },
961 |retry_state| {
962 json!({
963 "1": { "type": "TEXT", "value": rating_value },
964 "2": { "type": "TEXT", "value": rating_value },
965 "3": { "type": "TEXT", "value": retry_state.resume_after },
966 "4": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
967 "5": { "type": "FIXED", "value": retry_state.offset.to_string() }
968 })
969 },
970 &[
971 "request_id",
972 "inputs",
973 "output",
974 "rating",
975 "feedback",
976 "device_id",
977 "time",
978 "experiment_name",
979 "environment",
980 "zed_version",
981 ],
982 rated_examples_from_response,
983 )
984 .await?;
985
986 all_examples.extend(examples);
987 }
988
989 Ok(all_examples)
990}
991
992fn rated_examples_from_response<'a>(
993 response: &'a SnowflakeStatementResponse,
994 column_indices: &'a std::collections::HashMap<String, usize>,
995) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
996 if let Some(code) = &response.code {
997 if code != SNOWFLAKE_SUCCESS_CODE {
998 anyhow::bail!(
999 "snowflake sql api returned error code={code} message={}",
1000 response.message.as_deref().unwrap_or("<no message>")
1001 );
1002 }
1003 }
1004
1005 let iter = response
1006 .data
1007 .iter()
1008 .enumerate()
1009 .filter_map(move |(row_index, data_row)| {
1010 let get_string = |name: &str| -> Option<String> {
1011 let index = column_indices.get(name).copied()?;
1012 match data_row.get(index)? {
1013 JsonValue::String(s) => Some(s.clone()),
1014 JsonValue::Null => None,
1015 other => Some(other.to_string()),
1016 }
1017 };
1018
1019 let get_json = |name: &str| -> Option<JsonValue> {
1020 let index = column_indices.get(name).copied()?;
1021 let value = data_row.get(index)?;
1022 if value.is_null() {
1023 return None;
1024 }
1025 match value {
1026 JsonValue::String(s) => serde_json::from_str(s).ok(),
1027 other => Some(other.clone()),
1028 }
1029 };
1030
1031 let request_id = get_string("request_id");
1032 let inputs_json = get_json("inputs");
1033 let inputs: Option<ZetaPromptInput> = match &inputs_json {
1034 Some(v) => match serde_json::from_value(v.clone()) {
1035 Ok(parsed) => Some(parsed),
1036 Err(e) => {
1037 log::warn!(
1038 "skipping row {row_index}: failed to parse inputs - {e}",
1039 );
1040 return None;
1041 }
1042 },
1043 None => None,
1044 };
1045 let output = get_string("output");
1046 let rating = get_string("rating");
1047 let feedback = get_string("feedback").unwrap_or_default();
1048 let device_id = get_string("device_id");
1049 let time = get_string("time");
1050 let experiment_name = get_string("experiment_name");
1051 let environment = get_string("environment");
1052 let zed_version = get_string("zed_version");
1053
1054 match (inputs, output.clone(), rating.clone(), time.clone()) {
1055 (Some(inputs), Some(output), Some(rating), Some(time)) => {
1056 Some(build_rated_example(
1057 request_id,
1058 device_id.unwrap_or_default(),
1059 time,
1060 inputs,
1061 output,
1062 rating,
1063 feedback,
1064 experiment_name,
1065 environment,
1066 zed_version,
1067 ))
1068 }
1069 _ => {
1070 log::warn!(
1071 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} time={:?}",
1072 inputs_json.is_some(),
1073 output.is_some(),
1074 rating.is_some(),
1075 time.is_some(),
1076 );
1077 None
1078 }
1079 }
1080 });
1081
1082 Ok(Box::new(iter))
1083}
1084
1085fn build_rated_example(
1086 request_id: Option<String>,
1087 device_id: String,
1088 time: String,
1089 input: ZetaPromptInput,
1090 output: String,
1091 rating: String,
1092 feedback: String,
1093 experiment_name: Option<String>,
1094 environment: Option<String>,
1095 zed_version: Option<String>,
1096) -> Example {
1097 let parsed_rating = if rating == "Positive" {
1098 EditPredictionRating::Positive
1099 } else {
1100 EditPredictionRating::Negative
1101 };
1102 let is_positive = parsed_rating == EditPredictionRating::Positive;
1103 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1104
1105 let mut tags = Vec::with_capacity(3);
1106 tags.push(if is_positive {
1107 "rated:positive".to_string()
1108 } else {
1109 "rated:negative".to_string()
1110 });
1111 if let Some(experiment) = experiment_name {
1112 tags.push(format!("experiment:{experiment}"));
1113 }
1114 if let Some(env) = environment {
1115 tags.push(format!("environment:{env}"));
1116 }
1117
1118 let mut example =
1119 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1120
1121 example.spec.rating = Some(parsed_rating);
1122
1123 if !feedback.is_empty() {
1124 example
1125 .spec
1126 .human_feedback
1127 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1128 }
1129
1130 if is_positive {
1131 example.spec.expected_patches = vec![output];
1132 } else {
1133 example.spec.rejected_patch = Some(output);
1134 }
1135
1136 example
1137}
1138
1139fn requested_examples_from_response<'a>(
1140 response: &'a SnowflakeStatementResponse,
1141 column_indices: &'a std::collections::HashMap<String, usize>,
1142) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1143 if let Some(code) = &response.code {
1144 if code != SNOWFLAKE_SUCCESS_CODE {
1145 anyhow::bail!(
1146 "snowflake sql api returned error code={code} message={}",
1147 response.message.as_deref().unwrap_or("<no message>")
1148 );
1149 }
1150 }
1151
1152 let iter = response
1153 .data
1154 .iter()
1155 .enumerate()
1156 .filter_map(move |(row_index, data_row)| {
1157 let get_string = |name: &str| -> Option<String> {
1158 let index = column_indices.get(name).copied()?;
1159 match data_row.get(index)? {
1160 JsonValue::String(s) => Some(s.clone()),
1161 JsonValue::Null => None,
1162 other => Some(other.to_string()),
1163 }
1164 };
1165
1166 let get_json = |name: &str| -> Option<JsonValue> {
1167 let index = column_indices.get(name).copied()?;
1168 let value = data_row.get(index)?;
1169 if value.is_null() {
1170 return None;
1171 }
1172 match value {
1173 JsonValue::String(s) => serde_json::from_str(s).ok(),
1174 other => Some(other.clone()),
1175 }
1176 };
1177
1178 let request_id_str = get_string("request_id");
1179 let device_id = get_string("device_id");
1180 let time = get_string("time");
1181 let input_json = get_json("input");
1182 let input: Option<ZetaPromptInput> =
1183 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1184 let zed_version = get_string("zed_version");
1185
1186 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1187 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1188 Some(build_example_from_snowflake(
1189 request_id,
1190 device_id,
1191 time,
1192 input,
1193 vec!["requested".to_string()],
1194 None,
1195 zed_version,
1196 ))
1197 }
1198 _ => {
1199 log::warn!(
1200 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1201 request_id_str.is_some(),
1202 device_id.is_some(),
1203 time.is_some(),
1204 input_json.is_some(),
1205 );
1206 None
1207 }
1208 }
1209 });
1210
1211 Ok(Box::new(iter))
1212}
1213
1214fn settled_examples_from_response<'a>(
1215 response: &'a SnowflakeStatementResponse,
1216 column_indices: &'a std::collections::HashMap<String, usize>,
1217) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1218 if let Some(code) = &response.code {
1219 if code != SNOWFLAKE_SUCCESS_CODE {
1220 anyhow::bail!(
1221 "snowflake sql api returned error code={code} message={}",
1222 response.message.as_deref().unwrap_or("<no message>")
1223 );
1224 }
1225 }
1226
1227 let iter = response
1228 .data
1229 .iter()
1230 .enumerate()
1231 .filter_map(move |(row_index, data_row)| {
1232 let get_value = |name: &str| -> Option<JsonValue> {
1233 let index = column_indices.get(name).copied()?;
1234 let value = data_row.get(index)?;
1235 if value.is_null() {
1236 None
1237 } else {
1238 Some(value.clone())
1239 }
1240 };
1241
1242 let get_string = |name: &str| -> Option<String> {
1243 match get_value(name)? {
1244 JsonValue::String(s) => Some(s),
1245 other => Some(other.to_string()),
1246 }
1247 };
1248
1249 let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1250 let value = raw?;
1251 match value {
1252 JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1253 other => Some(other.clone()),
1254 }
1255 };
1256
1257 let request_id_str = get_string("request_id");
1258 let device_id = get_string("device_id");
1259 let time = get_string("time");
1260 let input_raw = get_value("input");
1261 let input_json = parse_json_value(input_raw.as_ref());
1262 let input: Option<ZetaPromptInput> = input_json
1263 .as_ref()
1264 .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1265 let requested_output = get_string("requested_output");
1266 let settled_editable_region = get_string("settled_editable_region");
1267 let requested_format =
1268 get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1269 let zed_version = get_string("zed_version");
1270
1271 match (
1272 request_id_str.clone(),
1273 device_id.clone(),
1274 time.clone(),
1275 input.clone(),
1276 requested_output.clone(),
1277 settled_editable_region.clone(),
1278 requested_format,
1279 ) {
1280 (
1281 Some(request_id),
1282 Some(device_id),
1283 Some(time),
1284 Some(input),
1285 Some(requested_output),
1286 Some(settled_editable_region),
1287 Some(requested_format),
1288 ) => Some(build_settled_example(
1289 request_id,
1290 device_id,
1291 time,
1292 input,
1293 requested_output,
1294 settled_editable_region,
1295 requested_format,
1296 zed_version,
1297 )),
1298 _ => {
1299 let mut missing_fields = Vec::new();
1300
1301 if request_id_str.is_none() {
1302 missing_fields.push("request_id");
1303 }
1304 if device_id.is_none() {
1305 missing_fields.push("device_id");
1306 }
1307 if time.is_none() {
1308 missing_fields.push("time");
1309 }
1310 if input_raw.is_none() || input_json.is_none() || input.is_none() {
1311 missing_fields.push("input");
1312 }
1313 if requested_output.is_none() {
1314 missing_fields.push("requested_output");
1315 }
1316 if settled_editable_region.is_none() {
1317 missing_fields.push("settled_editable_region");
1318 }
1319 if requested_format.is_none() {
1320 missing_fields.push("requested_format");
1321 }
1322
1323 log::warn!(
1324 "skipping settled row {row_index}: [{}]",
1325 missing_fields.join(", "),
1326 );
1327 None
1328 }
1329 }
1330 });
1331
1332 Ok(Box::new(iter))
1333}
1334
1335fn captured_examples_from_response<'a>(
1336 response: &'a SnowflakeStatementResponse,
1337 column_indices: &'a std::collections::HashMap<String, usize>,
1338) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1339 if let Some(code) = &response.code {
1340 if code != SNOWFLAKE_SUCCESS_CODE {
1341 anyhow::bail!(
1342 "snowflake sql api returned error code={code} message={}",
1343 response.message.as_deref().unwrap_or("<no message>")
1344 );
1345 }
1346 }
1347
1348 let iter = response
1349 .data
1350 .iter()
1351 .enumerate()
1352 .filter_map(move |(row_index, data_row)| {
1353 let get_value = |name: &str| -> Option<JsonValue> {
1354 let index = column_indices.get(name).copied()?;
1355 let value = data_row.get(index)?;
1356 if value.is_null() {
1357 None
1358 } else {
1359 Some(value.clone())
1360 }
1361 };
1362
1363 let get_string = |name: &str| -> Option<String> {
1364 match get_value(name)? {
1365 JsonValue::String(s) => Some(s),
1366 other => Some(other.to_string()),
1367 }
1368 };
1369
1370 let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1371 let value = raw?;
1372 match value {
1373 JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1374 other => Some(other.clone()),
1375 }
1376 };
1377
1378 let request_id = get_string("request_id");
1379 let device_id = get_string("device_id");
1380 let time = get_string("time");
1381 let input_raw = get_value("input");
1382 let input_json = parse_json_value(input_raw.as_ref());
1383 let input: Option<ZetaPromptInput> = input_json
1384 .as_ref()
1385 .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1386 let example_raw = get_value("example");
1387 let example_json = parse_json_value(example_raw.as_ref());
1388 let example_spec: Option<ExampleSpec> = example_json.as_ref().and_then(|parsed| {
1389 serde_json::from_value(parsed.clone())
1390 .or_else(|_| {
1391 parsed
1392 .as_str()
1393 .and_then(|markdown| ExampleSpec::from_markdown(markdown).ok())
1394 .ok_or_else(|| {
1395 serde_json::Error::io(std::io::Error::other("not markdown"))
1396 })
1397 })
1398 .ok()
1399 });
1400 let has_example_spec = example_spec.is_some();
1401 let settled_editable_region = get_string("settled_editable_region");
1402 let zed_version = get_string("zed_version");
1403
1404 match (
1405 request_id.clone(),
1406 device_id.clone(),
1407 time.clone(),
1408 input.clone(),
1409 example_spec,
1410 settled_editable_region.clone(),
1411 ) {
1412 (
1413 Some(request_id),
1414 Some(device_id),
1415 Some(time),
1416 Some(input),
1417 Some(example_spec),
1418 Some(settled_editable_region),
1419 ) => Some(build_captured_example(
1420 request_id,
1421 device_id,
1422 time,
1423 input,
1424 example_spec,
1425 settled_editable_region,
1426 zed_version,
1427 )),
1428 _ => {
1429 let mut missing_fields = Vec::new();
1430
1431 if request_id.is_none() {
1432 missing_fields.push("request_id");
1433 }
1434 if device_id.is_none() {
1435 missing_fields.push("device_id");
1436 }
1437 if time.is_none() {
1438 missing_fields.push("time");
1439 }
1440 if input_raw.is_none() || input_json.is_none() || input.is_none() {
1441 missing_fields.push("input");
1442 }
1443 if example_raw.is_none() || !has_example_spec {
1444 missing_fields.push("example");
1445 }
1446 if settled_editable_region.is_none() {
1447 missing_fields.push("settled_editable_region");
1448 }
1449
1450 log::warn!(
1451 "skipping captured row {row_index}: [{}]",
1452 missing_fields.join(", "),
1453 );
1454 None
1455 }
1456 }
1457 });
1458
1459 Ok(Box::new(iter))
1460}
1461
1462fn build_settled_example(
1463 request_id: String,
1464 device_id: String,
1465 time: String,
1466 input: ZetaPromptInput,
1467 requested_output: String,
1468 settled_editable_region: String,
1469 requested_format: ZetaFormat,
1470 zed_version: Option<String>,
1471) -> Example {
1472 let requested_editable_range =
1473 excerpt_range_for_format(requested_format, &input.excerpt_ranges).0;
1474
1475 let base_cursor_excerpt = input.cursor_excerpt.to_string();
1476
1477 let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1478 && requested_editable_range.end <= base_cursor_excerpt.len();
1479 let mut example = build_example_from_snowflake(
1480 request_id.clone(),
1481 device_id,
1482 time,
1483 input,
1484 vec!["settled".to_string()],
1485 None,
1486 zed_version,
1487 );
1488
1489 if !requested_range_is_valid {
1490 log::warn!(
1491 "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1492 request_id,
1493 requested_editable_range,
1494 base_cursor_excerpt.len(),
1495 );
1496 return example;
1497 }
1498
1499 let settled_replacement = settled_editable_region.as_str();
1500 let rejected_patch = build_output_patch(
1501 &example.spec.cursor_path,
1502 &base_cursor_excerpt,
1503 &requested_editable_range,
1504 &requested_output,
1505 );
1506 let expected_patch = build_output_patch(
1507 &example.spec.cursor_path,
1508 &base_cursor_excerpt,
1509 &requested_editable_range,
1510 settled_replacement,
1511 );
1512
1513 example.spec.expected_patches = vec![expected_patch];
1514 example.spec.rejected_patch = Some(rejected_patch);
1515 example
1516}
1517
1518fn build_captured_example(
1519 request_id: String,
1520 device_id: String,
1521 time: String,
1522 input: ZetaPromptInput,
1523 mut example_spec: ExampleSpec,
1524 settled_editable_region: String,
1525 zed_version: Option<String>,
1526) -> Example {
1527 let expected_patch = build_output_patch(
1528 &input.cursor_path,
1529 input.cursor_excerpt.as_ref(),
1530 &input.excerpt_ranges.editable_350,
1531 settled_editable_region.as_str(),
1532 );
1533
1534 example_spec.expected_patches = vec![expected_patch];
1535 example_spec.telemetry = Some(TelemetrySource {
1536 request_id,
1537 device_id,
1538 time,
1539 rejection_reason: String::new(),
1540 was_shown: false,
1541 });
1542
1543 Example {
1544 spec: example_spec,
1545 zed_version,
1546 prompt_inputs: Some(input),
1547 prompt: None,
1548 predictions: Vec::new(),
1549 score: Vec::new(),
1550 qa: Vec::new(),
1551 state: None,
1552 }
1553}
1554
1555fn rejected_examples_from_response<'a>(
1556 response: &'a SnowflakeStatementResponse,
1557 column_indices: &'a std::collections::HashMap<String, usize>,
1558) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1559 if let Some(code) = &response.code {
1560 if code != SNOWFLAKE_SUCCESS_CODE {
1561 anyhow::bail!(
1562 "snowflake sql api returned error code={code} message={}",
1563 response.message.as_deref().unwrap_or("<no message>")
1564 );
1565 }
1566 }
1567
1568 let iter = response
1569 .data
1570 .iter()
1571 .enumerate()
1572 .filter_map(move |(row_index, data_row)| {
1573 let get_string = |name: &str| -> Option<String> {
1574 let index = column_indices.get(name).copied()?;
1575 match data_row.get(index)? {
1576 JsonValue::String(s) => Some(s.clone()),
1577 JsonValue::Null => None,
1578 other => Some(other.to_string()),
1579 }
1580 };
1581
1582 let get_json = |name: &str| -> Option<JsonValue> {
1583 let index = column_indices.get(name).copied()?;
1584 let value = data_row.get(index)?;
1585 if value.is_null() {
1586 return None;
1587 }
1588 match value {
1589 JsonValue::String(s) => serde_json::from_str(s).ok(),
1590 other => Some(other.clone()),
1591 }
1592 };
1593
1594 let get_bool = |name: &str| -> Option<bool> {
1595 let index = column_indices.get(name).copied()?;
1596 match data_row.get(index)? {
1597 JsonValue::Bool(b) => Some(*b),
1598 JsonValue::String(s) => s.parse().ok(),
1599 _ => None,
1600 }
1601 };
1602
1603 let request_id_str = get_string("request_id");
1604 let device_id = get_string("device_id");
1605 let time = get_string("time");
1606 let input_json = get_json("input");
1607 let input: Option<ZetaPromptInput> =
1608 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1609 let prompt = get_string("prompt");
1610 let output = get_string("output");
1611 let was_shown = get_bool("was_shown");
1612 let reason = get_string("reason");
1613 let zed_version = get_string("zed_version");
1614
1615 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1616 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1617 Some(build_rejected_example(
1618 request_id,
1619 device_id,
1620 time,
1621 input,
1622 prompt,
1623 output,
1624 was_shown,
1625 reason,
1626 zed_version,
1627 ))
1628 }
1629 _ => {
1630 log::warn!(
1631 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1632 request_id_str.is_some(),
1633 device_id.is_some(),
1634 time.is_some(),
1635 input_json.is_some(),
1636 output.is_some(),
1637 was_shown.is_some(),
1638 reason.is_some()
1639 );
1640 None
1641 }
1642 }
1643 });
1644
1645 Ok(Box::new(iter))
1646}
1647
1648fn build_rejected_example(
1649 request_id: String,
1650 device_id: String,
1651 time: String,
1652 input: ZetaPromptInput,
1653 prompt: Option<String>,
1654 output: String,
1655 was_shown: bool,
1656 reason: String,
1657 zed_version: Option<String>,
1658) -> Example {
1659 let rejected_patch = build_output_patch(
1660 &input.cursor_path,
1661 input.cursor_excerpt.as_ref(),
1662 &input.excerpt_ranges.editable_350,
1663 &output,
1664 );
1665 let mut example = build_example_from_snowflake(
1666 request_id,
1667 device_id,
1668 time,
1669 input,
1670 vec![format!("rejection:{}", reason.to_lowercase())],
1671 Some(RejectionInfo { reason, was_shown }),
1672 zed_version,
1673 );
1674 example.spec.rejected_patch = Some(rejected_patch);
1675 example.prompt = prompt.map(|prompt| ExamplePrompt {
1676 input: prompt,
1677 expected_output: None,
1678 rejected_output: Some(output),
1679 prefill: None,
1680 provider: PredictionProvider::default(),
1681 });
1682 example
1683}
1684
/// Rejection details for a row, copied into the example's `TelemetrySource`.
struct RejectionInfo {
    /// Value of the row's `was_shown` column.
    was_shown: bool,
    /// Value of the row's `reason` column (stored verbatim, not lowercased).
    reason: String,
}
1689
1690fn build_example_from_snowflake(
1691 request_id: String,
1692 device_id: String,
1693 time: String,
1694 input: ZetaPromptInput,
1695 tags: Vec<String>,
1696 rejection: Option<RejectionInfo>,
1697 zed_version: Option<String>,
1698) -> Example {
1699 let cursor_excerpt = input.cursor_excerpt.as_ref();
1700 let cursor_offset = input.cursor_offset_in_excerpt;
1701
1702 let mut edit_history = String::new();
1703 for event in &input.events {
1704 zeta_prompt::write_event(&mut edit_history, event);
1705 edit_history.push('\n');
1706 }
1707
1708 let (rejection_reason, was_shown) = match &rejection {
1709 Some(r) => (r.reason.clone(), r.was_shown),
1710 None => (String::new(), false),
1711 };
1712
1713 let spec = ExampleSpec {
1714 name: request_id.clone(),
1715 repository_url: String::new(),
1716 revision: String::new(),
1717 tags,
1718 reasoning: None,
1719 uncommitted_diff: String::new(),
1720 cursor_path: input.cursor_path.clone(),
1721 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1722 edit_history,
1723 expected_patches: Vec::new(),
1724 rejected_patch: None,
1725 telemetry: Some(TelemetrySource {
1726 request_id,
1727 device_id,
1728 time,
1729 rejection_reason,
1730 was_shown,
1731 }),
1732 human_feedback: Vec::new(),
1733 rating: None,
1734 };
1735
1736 Example {
1737 spec,
1738 zed_version,
1739 prompt_inputs: Some(input),
1740 prompt: None,
1741 predictions: Vec::new(),
1742 score: Vec::new(),
1743 qa: Vec::new(),
1744 state: None,
1745 }
1746}
1747
/// Returns `excerpt` with the `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. Offsets past the end put the marker at the end.
/// An offset that falls inside a multi-byte UTF-8 character is rounded down to the
/// start of that character. Slicing there directly would panic, and these offsets
/// come from external telemetry.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1753
1754fn build_output_patch(
1755 cursor_path: &std::path::Path,
1756 cursor_excerpt: &str,
1757 editable_range: &std::ops::Range<usize>,
1758 model_output: &str,
1759) -> String {
1760 let old_text = &cursor_excerpt[editable_range.clone()];
1761
1762 let editable_start_row = cursor_excerpt[..editable_range.start]
1763 .chars()
1764 .filter(|&c| c == '\n')
1765 .count() as u32;
1766
1767 let diff_body = language::unified_diff_with_offsets(
1768 old_text,
1769 model_output,
1770 editable_start_row,
1771 editable_start_row,
1772 );
1773
1774 let mut patch = String::new();
1775 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1776 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1777 patch.push_str(&diff_body);
1778 patch
1779}
1780
1781fn is_timeout_response(response: &SnowflakeStatementResponse) -> bool {
1782 response.code.as_deref() == Some(SNOWFLAKE_TIMEOUT_CODE)
1783 && response
1784 .message
1785 .as_deref()
1786 .map(|message| message.to_ascii_lowercase().contains("timeout"))
1787 .unwrap_or(false)
1788}
1789
1790fn is_snowflake_timeout_error(error: &anyhow::Error) -> bool {
1791 error
1792 .chain()
1793 .any(|cause| cause.to_string().contains(SNOWFLAKE_TIMEOUT_CODE))
1794}
1795
1796fn last_continuation_timestamp_from_response(
1797 response: &SnowflakeStatementResponse,
1798 column_indices: &HashMap<String, usize>,
1799) -> Option<String> {
1800 let continuation_time_index = column_indices.get("continuation_time").copied()?;
1801 response
1802 .data
1803 .iter()
1804 .rev()
1805 .find_map(|row| match row.get(continuation_time_index)? {
1806 JsonValue::String(value) => Some(value.clone()),
1807 JsonValue::Null => None,
1808 other => Some(other.to_string()),
1809 })
1810}
1811
1812pub(crate) fn get_column_indices(
1813 meta: &Option<SnowflakeResultSetMetaData>,
1814 names: &[&str],
1815) -> HashMap<String, usize> {
1816 let mut indices = HashMap::new();
1817 if let Some(meta) = meta {
1818 for (index, col) in meta.row_type.iter().enumerate() {
1819 for &name in names {
1820 if col.name.eq_ignore_ascii_case(name) {
1821 indices.insert(name.to_string(), index);
1822 }
1823 }
1824 }
1825 }
1826 indices
1827}