1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::fmt::Write as _;
9use std::io::Read;
10use std::sync::Arc;
11use std::time::Duration;
12use telemetry_events::EditPredictionRating;
13
14use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
15
16use crate::example::Example;
17use crate::progress::{InfoStyle, Progress, Step};
/// Telemetry `event_type` of model-deployment events. The rated-examples query joins on
/// these to attach `experiment_name` / `environment` tags to each rated example.
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
19use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
20
/// Snowflake SQL API response `code` for a statement that finished successfully; any other
/// code on a finished response is reported as an error by the row converters.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Snowflake SQL API response `code` for a statement that is still executing
/// asynchronously; `run_sql_with_polling` keeps polling while this code is returned.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
// Telemetry `event_type` values matched against the `events` table.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
27
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
/// where `zed_version >= 0.224.1`.
///
/// Only the minor and patch components are compared by the queries in this file; the
/// major component of `zed_version` is not checked. In this file the filter is applied by
/// `fetch_rejected_examples_after` and `fetch_requested_examples_after`.
#[derive(Clone, Copy, Debug)]
pub struct MinCaptureVersion {
    /// Minimum minor version (the second `.`-separated component of `zed_version`).
    pub minor: u32,
    /// Minimum patch version, only consulted when the minor version is equal. Any
    /// `+suffix` after the patch number is stripped before comparing.
    pub patch: u32,
}
36
/// Statement `timeout` (seconds) sent with most Snowflake queries.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Longer statement `timeout` (seconds) used for the settled-examples query.
const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Delay between polls of a statement that is still executing asynchronously.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls before giving up (120 polls x 2 s = 240 s of waiting).
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
41
/// Extracts the timestamp from an input token shaped like `captured-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with `captured-after:`. A bare prefix
/// (`"captured-after:"`) yields `Some("")`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input
        .starts_with(PREFIX)
        .then(|| &input[PREFIX.len()..])
}
46
/// Extracts the timestamp from an input token shaped like `rejected-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with `rejected-after:`. A bare prefix
/// (`"rejected-after:"`) yields `Some("")`.
pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "rejected-after:";
    input
        .starts_with(PREFIX)
        .then(|| &input[PREFIX.len()..])
}
51
/// Extracts the timestamp from an input token shaped like `requested-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with `requested-after:`. A bare prefix
/// (`"requested-after:"`) yields `Some("")`.
pub fn parse_requested_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "requested-after:";
    input
        .starts_with(PREFIX)
        .then(|| &input[PREFIX.len()..])
}
56
/// Extracts the timestamp from an input token shaped like `settled-after:{timestamp}`.
///
/// Returns `None` when `input` does not begin with `settled-after:`. A bare prefix
/// (`"settled-after:"`) yields `Some("")`.
pub fn parse_settled_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "settled-after:";
    input
        .starts_with(PREFIX)
        .then(|| &input[PREFIX.len()..])
}
61
62/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
63/// or `rated-negative-after:{timestamp}`.
64/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
65pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
66 if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
67 Some((timestamp, Some(EditPredictionRating::Positive)))
68 } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
69 Some((timestamp, Some(EditPredictionRating::Negative)))
70 } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
71 Some((timestamp, None))
72 } else {
73 None
74 }
75}
76
/// JSON body returned by the Snowflake SQL API, both when a statement is submitted
/// (`POST /api/v2/statements`) and when a statement or partition is fetched
/// (`GET /api/v2/statements/{handle}`). Fields are camelCase on the wire, and every field
/// is optional so that in-progress and error responses also deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows. Each row is a list of column values in `rowType` order; empty when
    /// the key is absent.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Column, row-count and partition metadata, when Snowflake includes it.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code, compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Human-readable status or error message, used in error reports.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll a running statement and to fetch further result partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
91
/// `resultSetMetaData` section of a Snowflake SQL API response.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeResultSetMetaData {
    /// Per-column metadata, in the same order as the values in each data row.
    // NOTE(review): the explicit `rename` duplicates what `rename_all = "camelCase"`
    // already produces; it is harmless.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Total row count reported by Snowflake. The fetchers show it as the row total and
    /// fall back to the first page's row count when it is absent.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition. Only the count is used, to decide how many
    /// partitions to download.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
102
/// One entry of `partitionInfo`. Its fields are intentionally not captured (serde ignores
/// them); only the number of entries matters to the callers in this file.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
106
/// One entry of `rowType`, describing a result column.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name (empty when absent). Presumably matched against the SQL aliases when
    /// resolving column indices; confirm against `get_column_indices`.
    #[serde(default)]
    name: String,
}
112
113async fn run_sql_with_polling(
114 http_client: Arc<dyn HttpClient>,
115 base_url: &str,
116 token: &str,
117 request: &serde_json::Value,
118 step_progress: &crate::progress::StepProgress,
119 background_executor: BackgroundExecutor,
120) -> Result<SnowflakeStatementResponse> {
121 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
122
123 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
124 let statement_handle = response
125 .statement_handle
126 .as_ref()
127 .context("async query response missing statementHandle")?
128 .clone();
129
130 for attempt in 1..=MAX_POLL_ATTEMPTS {
131 step_progress.set_substatus(format!("polling ({attempt})"));
132
133 background_executor.timer(POLL_INTERVAL).await;
134
135 response =
136 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
137
138 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
139 break;
140 }
141 }
142
143 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
144 anyhow::bail!(
145 "query still running after {} poll attempts ({} seconds)",
146 MAX_POLL_ATTEMPTS,
147 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
148 );
149 }
150 }
151
152 Ok(response)
153}
154
155pub(crate) async fn fetch_partition(
156 http_client: Arc<dyn HttpClient>,
157 base_url: &str,
158 token: &str,
159 statement_handle: &str,
160 partition: usize,
161) -> Result<SnowflakeStatementResponse> {
162 let url = format!(
163 "{}/api/v2/statements/{}?partition={}",
164 base_url.trim_end_matches('/'),
165 statement_handle,
166 partition
167 );
168
169 let http_request = Request::builder()
170 .method(Method::GET)
171 .uri(url.as_str())
172 .header("Authorization", format!("Bearer {token}"))
173 .header(
174 "X-Snowflake-Authorization-Token-Type",
175 "PROGRAMMATIC_ACCESS_TOKEN",
176 )
177 .header("Accept", "application/json")
178 .header("Accept-Encoding", "gzip")
179 .header("User-Agent", "edit_prediction_cli")
180 .body(AsyncBody::empty())?;
181
182 let response = http_client
183 .send(http_request)
184 .await
185 .context("failed to send partition request to Snowflake SQL API")?;
186
187 let status = response.status();
188 let content_encoding = response
189 .headers()
190 .get("content-encoding")
191 .and_then(|v| v.to_str().ok())
192 .map(|s| s.to_lowercase());
193
194 let body_bytes = {
195 use futures::AsyncReadExt as _;
196
197 let mut body = response.into_body();
198 let mut bytes = Vec::new();
199 body.read_to_end(&mut bytes)
200 .await
201 .context("failed to read Snowflake SQL API partition response body")?;
202 bytes
203 };
204
205 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
206 let mut decoder = GzDecoder::new(&body_bytes[..]);
207 let mut decompressed = Vec::new();
208 decoder
209 .read_to_end(&mut decompressed)
210 .context("failed to decompress gzip response")?;
211 decompressed
212 } else {
213 body_bytes
214 };
215
216 if !status.is_success() && status.as_u16() != 202 {
217 let body_text = String::from_utf8_lossy(&body_bytes);
218 anyhow::bail!(
219 "snowflake sql api partition request http {}: {}",
220 status.as_u16(),
221 body_text
222 );
223 }
224
225 if body_bytes.is_empty() {
226 anyhow::bail!(
227 "snowflake sql api partition {} returned empty response body (http {})",
228 partition,
229 status.as_u16()
230 );
231 }
232
233 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
234 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
235 format!(
236 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
237 partition,
238 status.as_u16(),
239 body_preview
240 )
241 })
242}
243
244pub(crate) async fn run_sql(
245 http_client: Arc<dyn HttpClient>,
246 base_url: &str,
247 token: &str,
248 request: &serde_json::Value,
249) -> Result<SnowflakeStatementResponse> {
250 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
251
252 let request_body =
253 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
254
255 let http_request = Request::builder()
256 .method(Method::POST)
257 .uri(url.as_str())
258 .header("Authorization", format!("Bearer {token}"))
259 .header(
260 "X-Snowflake-Authorization-Token-Type",
261 "PROGRAMMATIC_ACCESS_TOKEN",
262 )
263 .header("Content-Type", "application/json")
264 .header("Accept", "application/json")
265 .header("User-Agent", "edit_prediction_cli")
266 .body(AsyncBody::from(request_body.clone()))?;
267
268 let response = http_client
269 .send(http_request)
270 .await
271 .context("failed to send request to Snowflake SQL API")?;
272
273 let status = response.status();
274 let body_bytes = {
275 use futures::AsyncReadExt as _;
276
277 let mut body = response.into_body();
278 let mut bytes = Vec::new();
279 body.read_to_end(&mut bytes)
280 .await
281 .context("failed to read Snowflake SQL API response body")?;
282 bytes
283 };
284
285 if !status.is_success() && status.as_u16() != 202 {
286 let body_text = String::from_utf8_lossy(&body_bytes);
287 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
288 }
289
290 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
291 .context("failed to parse Snowflake SQL API response JSON")
292}
293
/// Fetches rejected edit-prediction examples recorded after each of the given timestamps.
///
/// For every entry in `after_timestamps`, one Snowflake query joins
/// `Predictive Edit Requested` events with `Predictive Edit Rejected` events on
/// `request_id`. The rows are converted with `rejected_examples_from_response`, and every
/// result partition is downloaded. Examples for all timestamps are concatenated in the
/// order the timestamps were given.
///
/// * `max_rows_per_timestamp` / `offset`: the `LIMIT` / `OFFSET` applied to each
///   timestamp's query. Rows are ordered by request time, ascending.
/// * `min_capture_version`: when `Some`, keeps only rows whose `zed_version` minor/patch is
///   at least `minor.patch`. The major component is not compared. While the filter is
///   active, rows whose version cannot be parsed are dropped.
///
/// Returns an empty list without reading the environment when `after_timestamps` is
/// empty. Otherwise reads `EP_SNOWFLAKE_API_KEY` and `EP_SNOWFLAKE_BASE_URL` (required)
/// and `EP_SNOWFLAKE_ROLE` (optional).
///
/// # Errors
/// Fails if:
/// * a required environment variable is missing;
/// * a Snowflake request fails, or the statement is still running after polling;
/// * a multi-partition response has no statement handle;
/// * `rejected_examples_from_response` returns an error.
pub async fn fetch_rejected_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Connection settings. The role is optional and is sent as JSON null when unset.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    // Each timestamp is queried separately and reported as its own progress step.
    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("rejected>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // Join rejected events with their corresponding request events to get the full context.
        // We filter for V3 sampling data which contains the structured input we need.
        // We also filter for predictions that were actually shown to the user (was_shown = true)
        // to focus on explicit user rejections rather than implicit cancellations.
        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:prompt::string AS prompt,
                req.event_properties:output::string AS output,
                rej.event_properties:was_shown::boolean AS was_shown,
                rej.event_properties:reason::string AS reason,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            INNER JOIN events rej
                ON req.event_properties:request_id = rej.event_properties:request_id
            WHERE req.event_type = ?
                AND rej.event_type = ?
                AND req.event_properties:version = 'V3'
                AND rej.event_properties:was_shown = true
                AND req.event_properties:input:can_collect_data = true
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // The version-filter values are bound as strings. `None` becomes JSON null, which
        // makes the `? IS NULL` guard disable the filter.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        // Bindings are positional and must follow the order of `?` in `statement`:
        // 1-2 event types, 3 lower time bound, 4-6 minor version (NULL guard, `>`, `=`),
        // 7 patch version, 8 LIMIT, 9 OFFSET.
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
                "3": { "type": "TEXT", "value": after_date },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_minor_str_ref },
                "7": { "type": "FIXED", "value": min_patch_str_ref },
                "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "9": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer Snowflake's reported total; fall back to the number of rows in this page.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Partition 0 is already in `response`; any further partitions are fetched below.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Look up the position of each selected column in the result metadata.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "device_id",
                "time",
                "input",
                "prompt",
                "output",
                "was_shown",
                "reason",
                "zed_version",
            ],
        );

        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rejected_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
459
/// Fetches edit-prediction request examples recorded after each of the given timestamps.
///
/// For every entry in `after_timestamps`, one Snowflake query selects
/// `Predictive Edit Requested` events with V3 input whose `can_collect_data` flag is
/// true. The rows are converted with `requested_examples_from_response`, and every result
/// partition is downloaded. Examples for all timestamps are concatenated in input order.
///
/// * `max_rows_per_timestamp` / `offset`: the `LIMIT` / `OFFSET` applied to each
///   timestamp's query. Rows are ordered by request time, ascending.
/// * `min_capture_version`: when `Some`, keeps only rows whose `zed_version` minor/patch is
///   at least `minor.patch`. The major component is not compared. While the filter is
///   active, rows whose version cannot be parsed are dropped.
///
/// Returns an empty list without reading the environment when `after_timestamps` is
/// empty. Otherwise reads `EP_SNOWFLAKE_API_KEY` and `EP_SNOWFLAKE_BASE_URL` (required)
/// and `EP_SNOWFLAKE_ROLE` (optional).
///
/// # Errors
/// Fails if:
/// * a required environment variable is missing;
/// * a Snowflake request fails, or the statement is still running after polling;
/// * a multi-partition response has no statement handle;
/// * a response carries a non-success code.
pub async fn fetch_requested_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Connection settings. The role is optional and is sent as JSON null when unset.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    // Each timestamp is queried separately and reported as its own progress step.
    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("requested>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            WHERE req.event_type = ?
                AND req.event_properties:version = 'V3'
                AND req.event_properties:input:can_collect_data = true
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // The version-filter values are bound as strings. `None` becomes JSON null, which
        // makes the `? IS NULL` guard disable the filter.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        // Bindings are positional and must follow the order of `?` in `statement`:
        // 1 event type, 2 lower time bound, 3-5 minor version (NULL guard, `>`, `=`),
        // 6 patch version, 7 LIMIT, 8 OFFSET.
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": after_date },
                "3": { "type": "FIXED", "value": min_minor_str_ref },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_patch_str_ref },
                "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "8": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer Snowflake's reported total; fall back to the number of rows in this page.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Partition 0 is already in `response`; any further partitions are fetched below.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Look up the position of each selected column in the result metadata.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &["request_id", "device_id", "time", "input", "zed_version"],
        );

        all_examples.extend(requested_examples_from_response(
            &response,
            &column_indices,
        )?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(requested_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
605
/// Fetches settled edit-prediction examples recorded after each of the given timestamps.
///
/// For every entry in `after_timestamps`, one Snowflake query selects V3
/// `Predictive Edit Requested` events whose `can_collect_data` flag is true and joins them
/// (on `request_id`) with `Edit Prediction Settled` events. Each row pairs the requested
/// input and output with the settled editable region. The rows are converted with
/// `settled_examples_from_response`, every result partition is downloaded, and examples
/// for all timestamps are concatenated in input order. This query uses the longer
/// `SETTLED_STATEMENT_TIMEOUT_SECONDS` statement timeout.
///
/// * `max_rows_per_timestamp` / `offset`: the `LIMIT` / `OFFSET` applied to each
///   timestamp's query. Rows are ordered by request time, ascending.
/// * `min_capture_version`: accepted but currently NOT applied to this query.
///
/// Returns an empty list without reading the environment when `after_timestamps` is
/// empty. Otherwise reads `EP_SNOWFLAKE_API_KEY` and `EP_SNOWFLAKE_BASE_URL` (required)
/// and `EP_SNOWFLAKE_ROLE` (optional).
///
/// # Errors
/// Fails if:
/// * a required environment variable is missing;
/// * a Snowflake request fails, or the statement is still running after polling;
/// * a multi-partition response has no statement handle;
/// * `settled_examples_from_response` returns an error.
pub async fn fetch_settled_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Connection settings. The role is optional and is sent as JSON null when unset.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    // Each timestamp is queried separately and reported as its own progress step.
    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("settled>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // The CTE narrows request events first. `req_time` keeps the raw timestamp for
        // ordering, while `time` is its string form, which is what the output returns.
        let statement = indoc! {r#"
            WITH requested AS (
                SELECT
                    req.event_properties:request_id::string AS request_id,
                    req.device_id::string AS device_id,
                    req.time AS req_time,
                    req.time::string AS time,
                    req.event_properties:input AS input,
                    req.event_properties:format::string AS requested_format,
                    req.event_properties:output::string AS requested_output,
                    req.event_properties:zed_version::string AS zed_version
                FROM events req
                WHERE req.event_type = ?
                    AND req.event_properties:version = 'V3'
                    AND req.event_properties:input:can_collect_data = true
                    AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
            )
            SELECT
                req.request_id AS request_id,
                req.device_id AS device_id,
                req.time AS time,
                req.input AS input,
                req.requested_output AS requested_output,
                settled.event_properties:settled_editable_region::string AS settled_editable_region,
                req.requested_format AS requested_format,
                req.zed_version AS zed_version
            FROM requested req
            INNER JOIN events settled
                ON req.request_id = settled.event_properties:request_id::string
            WHERE settled.event_type = ?
            ORDER BY req.req_time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // NOTE(review): the minimum capture version is deliberately discarded here, so the
        // query returns settled examples from every Zed version. Confirm this is intended;
        // the rejected/requested fetchers do apply it.
        let _ = min_capture_version;
        // Bindings are positional and must follow the order of `?` in `statement`:
        // 1 request event type, 2 lower time bound, 3 settled event type, 4 LIMIT, 5 OFFSET.
        let request = json!({
            "statement": statement,
            "timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": after_date },
                "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "5": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer Snowflake's reported total; fall back to the number of rows in this page.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Partition 0 is already in `response`; any further partitions are fetched below.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Look up the position of each selected column in the result metadata.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "device_id",
                "time",
                "input",
                "requested_output",
                "settled_editable_region",
                "requested_format",
                "zed_version",
            ],
        );

        all_examples.extend(settled_examples_from_response(&response, &column_indices)?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(settled_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
762
/// Fetches user-rated edit-prediction examples recorded after each given timestamp.
///
/// Each entry of `inputs` is `(timestamp, rating_filter)`; a `None` filter means all
/// ratings. For every entry, one Snowflake query selects `Edit Prediction Rated` events
/// that:
/// * have inputs with a `cursor_excerpt`;
/// * have an output;
/// * have `can_collect_data = true`.
///
/// The query LEFT JOINs each rated event to its `Predictive Edit Requested` event, and
/// from there to `Edit Prediction Deployment` events whose model id/version match the
/// request's Baseten headers. This supplies the optional `experiment_name` and
/// `environment` columns. Rows are converted with `rated_examples_from_response`, every
/// result partition is downloaded, and results are concatenated in input order.
///
/// * `max_rows_per_timestamp` / `offset`: the `LIMIT` / `OFFSET` applied to each entry's
///   query. Rows are ordered by rating time, ascending.
/// * `_min_capture_version`: accepted but not applied to this query.
///
/// Returns an empty list without reading the environment when `inputs` is empty.
/// Otherwise reads `EP_SNOWFLAKE_API_KEY` and `EP_SNOWFLAKE_BASE_URL` (required) and
/// `EP_SNOWFLAKE_ROLE` (optional).
///
/// # Errors
/// Fails if:
/// * a required environment variable is missing;
/// * a Snowflake request fails, or the statement is still running after polling;
/// * a multi-partition response has no statement handle;
/// * a response carries a non-success code.
pub async fn fetch_rated_examples_after(
    http_client: Arc<dyn HttpClient>,
    inputs: &[(String, Option<EditPredictionRating>)],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    _min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if inputs.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Connection settings. The role is optional and is sent as JSON null when unset.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    // Each (timestamp, filter) pair is queried separately and reported as its own step.
    for (after_date, rating_filter) in inputs.iter() {
        let filter_label = match rating_filter {
            None => "",
            Some(EditPredictionRating::Positive) => ":positive",
            Some(EditPredictionRating::Negative) => ":negative",
        };
        let step_progress_name = format!("rated{filter_label}>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // The string stored in `event_properties:rating` for each rating. `None` binds as
        // null, which disables the `? IS NULL OR ...` rating filter.
        let rating_value = rating_filter.as_ref().map(|r| match r {
            EditPredictionRating::Positive => "Positive",
            EditPredictionRating::Negative => "Negative",
        });

        // NOTE(review): both joins are LEFT JOINs, so a rated event with several matching
        // request or deployment rows would appear more than once. Confirm that deployment
        // events are unique per model id/version.
        let statement = indoc! {r#"
            SELECT
                rated.event_properties:request_id::string AS request_id,
                rated.event_properties:inputs AS inputs,
                rated.event_properties:output::string AS output,
                rated.event_properties:rating::string AS rating,
                rated.event_properties:feedback::string AS feedback,
                rated.device_id::string AS device_id,
                rated.time::string AS time,
                deploy.event_properties:experiment_name::string AS experiment_name,
                deploy.event_properties:environment::string AS environment,
                rated.event_properties:zed_version::string AS zed_version
            FROM events rated
            LEFT JOIN events req
                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
                AND req.event_type = ?
            LEFT JOIN events deploy
                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
                AND deploy.event_type = ?
            WHERE rated.event_type = ?
                AND (? IS NULL OR rated.event_properties:rating::string = ?)
                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND rated.event_properties:inputs IS NOT NULL
                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
                AND rated.event_properties:output IS NOT NULL
                AND rated.event_properties:can_collect_data = true
            ORDER BY rated.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Bindings are positional and must follow the order of `?` in `statement`:
        // 1 request event type, 2 deployment event type, 3 rated event type,
        // 4-5 rating filter (NULL guard, equality), 6 lower time bound, 7 LIMIT, 8 OFFSET.
        let bindings = json!({
            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
            "4": { "type": "TEXT", "value": rating_value },
            "5": { "type": "TEXT", "value": rating_value },
            "6": { "type": "TEXT", "value": after_date },
            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
            "8": { "type": "FIXED", "value": offset.to_string() }
        });

        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": bindings
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer Snowflake's reported total; fall back to the number of rows in this page.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Partition 0 is already in `response`; any further partitions are fetched below.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Look up the position of each selected column in the result metadata.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "inputs",
                "output",
                "rating",
                "feedback",
                "device_id",
                "time",
                "experiment_name",
                "environment",
                "zed_version",
            ],
        );

        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rated_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
932
933fn rated_examples_from_response<'a>(
934 response: &'a SnowflakeStatementResponse,
935 column_indices: &'a std::collections::HashMap<String, usize>,
936) -> Result<impl Iterator<Item = Example> + 'a> {
937 if let Some(code) = &response.code {
938 if code != SNOWFLAKE_SUCCESS_CODE {
939 anyhow::bail!(
940 "snowflake sql api returned error code={code} message={}",
941 response.message.as_deref().unwrap_or("<no message>")
942 );
943 }
944 }
945
946 let iter = response
947 .data
948 .iter()
949 .enumerate()
950 .filter_map(move |(row_index, data_row)| {
951 let get_string = |name: &str| -> Option<String> {
952 let index = column_indices.get(name).copied()?;
953 match data_row.get(index)? {
954 JsonValue::String(s) => Some(s.clone()),
955 JsonValue::Null => None,
956 other => Some(other.to_string()),
957 }
958 };
959
960 let get_json = |name: &str| -> Option<JsonValue> {
961 let index = column_indices.get(name).copied()?;
962 let value = data_row.get(index)?;
963 if value.is_null() {
964 return None;
965 }
966 match value {
967 JsonValue::String(s) => serde_json::from_str(s).ok(),
968 other => Some(other.clone()),
969 }
970 };
971
972 let request_id = get_string("request_id");
973 let inputs_json = get_json("inputs");
974 let inputs: Option<ZetaPromptInput> = match &inputs_json {
975 Some(v) => match serde_json::from_value(v.clone()) {
976 Ok(parsed) => Some(parsed),
977 Err(e) => {
978 log::warn!(
979 "skipping row {row_index}: failed to parse inputs - {e}",
980 );
981 return None;
982 }
983 },
984 None => None,
985 };
986 let output = get_string("output");
987 let rating = get_string("rating");
988 let feedback = get_string("feedback").unwrap_or_default();
989 let device_id = get_string("device_id");
990 let time = get_string("time");
991 let experiment_name = get_string("experiment_name");
992 let environment = get_string("environment");
993 let zed_version = get_string("zed_version");
994
995 match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
996 (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
997 Some(build_rated_example(
998 request_id,
999 device_id,
1000 time,
1001 inputs,
1002 output,
1003 rating,
1004 feedback,
1005 experiment_name,
1006 environment,
1007 zed_version,
1008 ))
1009 }
1010 _ => {
1011 log::warn!(
1012 "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
1013 inputs_json.is_some(),
1014 output.is_some(),
1015 rating.is_some(),
1016 device_id.is_some(),
1017 time.is_some(),
1018 );
1019 None
1020 }
1021 }
1022 });
1023
1024 Ok(iter)
1025}
1026
1027fn build_rated_example(
1028 request_id: Option<String>,
1029 device_id: String,
1030 time: String,
1031 input: ZetaPromptInput,
1032 output: String,
1033 rating: String,
1034 feedback: String,
1035 experiment_name: Option<String>,
1036 environment: Option<String>,
1037 zed_version: Option<String>,
1038) -> Example {
1039 let parsed_rating = if rating == "Positive" {
1040 EditPredictionRating::Positive
1041 } else {
1042 EditPredictionRating::Negative
1043 };
1044 let is_positive = parsed_rating == EditPredictionRating::Positive;
1045 let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1046
1047 let mut tags = Vec::with_capacity(3);
1048 tags.push(if is_positive {
1049 "rated:positive".to_string()
1050 } else {
1051 "rated:negative".to_string()
1052 });
1053 if let Some(experiment) = experiment_name {
1054 tags.push(format!("experiment:{experiment}"));
1055 }
1056 if let Some(env) = environment {
1057 tags.push(format!("environment:{env}"));
1058 }
1059
1060 let mut example =
1061 build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1062
1063 example.spec.rating = Some(parsed_rating);
1064
1065 if !feedback.is_empty() {
1066 example
1067 .spec
1068 .human_feedback
1069 .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1070 }
1071
1072 if is_positive {
1073 example.spec.expected_patches = vec![output];
1074 } else {
1075 example.spec.rejected_patch = Some(output);
1076 }
1077
1078 example
1079}
1080
1081fn requested_examples_from_response<'a>(
1082 response: &'a SnowflakeStatementResponse,
1083 column_indices: &'a std::collections::HashMap<String, usize>,
1084) -> Result<impl Iterator<Item = Example> + 'a> {
1085 if let Some(code) = &response.code {
1086 if code != SNOWFLAKE_SUCCESS_CODE {
1087 anyhow::bail!(
1088 "snowflake sql api returned error code={code} message={}",
1089 response.message.as_deref().unwrap_or("<no message>")
1090 );
1091 }
1092 }
1093
1094 let iter = response
1095 .data
1096 .iter()
1097 .enumerate()
1098 .filter_map(move |(row_index, data_row)| {
1099 let get_string = |name: &str| -> Option<String> {
1100 let index = column_indices.get(name).copied()?;
1101 match data_row.get(index)? {
1102 JsonValue::String(s) => Some(s.clone()),
1103 JsonValue::Null => None,
1104 other => Some(other.to_string()),
1105 }
1106 };
1107
1108 let get_json = |name: &str| -> Option<JsonValue> {
1109 let index = column_indices.get(name).copied()?;
1110 let value = data_row.get(index)?;
1111 if value.is_null() {
1112 return None;
1113 }
1114 match value {
1115 JsonValue::String(s) => serde_json::from_str(s).ok(),
1116 other => Some(other.clone()),
1117 }
1118 };
1119
1120 let request_id_str = get_string("request_id");
1121 let device_id = get_string("device_id");
1122 let time = get_string("time");
1123 let input_json = get_json("input");
1124 let input: Option<ZetaPromptInput> =
1125 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1126 let zed_version = get_string("zed_version");
1127
1128 match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1129 (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1130 Some(build_example_from_snowflake(
1131 request_id,
1132 device_id,
1133 time,
1134 input,
1135 vec!["requested".to_string()],
1136 None,
1137 zed_version,
1138 ))
1139 }
1140 _ => {
1141 log::warn!(
1142 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1143 request_id_str.is_some(),
1144 device_id.is_some(),
1145 time.is_some(),
1146 input_json.is_some(),
1147 );
1148 None
1149 }
1150 }
1151 });
1152
1153 Ok(iter)
1154}
1155
1156fn settled_examples_from_response<'a>(
1157 response: &'a SnowflakeStatementResponse,
1158 column_indices: &'a std::collections::HashMap<String, usize>,
1159) -> Result<impl Iterator<Item = Example> + 'a> {
1160 if let Some(code) = &response.code {
1161 if code != SNOWFLAKE_SUCCESS_CODE {
1162 anyhow::bail!(
1163 "snowflake sql api returned error code={code} message={}",
1164 response.message.as_deref().unwrap_or("<no message>")
1165 );
1166 }
1167 }
1168
1169 let iter = response
1170 .data
1171 .iter()
1172 .enumerate()
1173 .filter_map(move |(row_index, data_row)| {
1174 let get_value = |name: &str| -> Option<JsonValue> {
1175 let index = column_indices.get(name).copied()?;
1176 let value = data_row.get(index)?;
1177 if value.is_null() {
1178 None
1179 } else {
1180 Some(value.clone())
1181 }
1182 };
1183
1184 let get_string = |name: &str| -> Option<String> {
1185 match get_value(name)? {
1186 JsonValue::String(s) => Some(s),
1187 other => Some(other.to_string()),
1188 }
1189 };
1190
1191 let parse_json_value = |_: &str, raw: Option<&JsonValue>| -> Option<JsonValue> {
1192 let value = raw?;
1193 match value {
1194 JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1195 other => Some(other.clone()),
1196 }
1197 };
1198
1199 let request_id_str = get_string("request_id");
1200 let device_id = get_string("device_id");
1201 let time = get_string("time");
1202 let input_raw = get_value("input");
1203 let input_json = parse_json_value("input", input_raw.as_ref());
1204 let input: Option<ZetaPromptInput> = input_json
1205 .as_ref()
1206 .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1207 let requested_output = get_string("requested_output");
1208 let settled_editable_region = get_string("settled_editable_region");
1209 let requested_format =
1210 get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1211 let zed_version = get_string("zed_version");
1212
1213 match (
1214 request_id_str.clone(),
1215 device_id.clone(),
1216 time.clone(),
1217 input.clone(),
1218 requested_output.clone(),
1219 settled_editable_region.clone(),
1220 requested_format,
1221 ) {
1222 (
1223 Some(request_id),
1224 Some(device_id),
1225 Some(time),
1226 Some(input),
1227 Some(requested_output),
1228 Some(settled_editable_region),
1229 Some(requested_format),
1230 ) => Some(build_settled_example(
1231 request_id,
1232 device_id,
1233 time,
1234 input,
1235 requested_output,
1236 settled_editable_region,
1237 requested_format,
1238 zed_version,
1239 )),
1240 _ => {
1241 let mut missing_fields = Vec::new();
1242
1243 if request_id_str.is_none() {
1244 missing_fields.push("request_id");
1245 }
1246 if device_id.is_none() {
1247 missing_fields.push("device_id");
1248 }
1249 if time.is_none() {
1250 missing_fields.push("time");
1251 }
1252 if input_raw.is_none() || input_json.is_none() || input.is_none() {
1253 missing_fields.push("input");
1254 }
1255 if requested_output.is_none() {
1256 missing_fields.push("requested_output");
1257 }
1258 if settled_editable_region.is_none() {
1259 missing_fields.push("settled_editable_region");
1260 }
1261 if requested_format.is_none() {
1262 missing_fields.push("requested_format");
1263 }
1264
1265 log::warn!(
1266 "skipping settled row {row_index}: [{}]",
1267 missing_fields.join(", "),
1268 );
1269 None
1270 }
1271 }
1272 });
1273
1274 Ok(iter)
1275}
1276
1277fn build_settled_example(
1278 request_id: String,
1279 device_id: String,
1280 time: String,
1281 input: ZetaPromptInput,
1282 requested_output: String,
1283 settled_editable_region: String,
1284 requested_format: ZetaFormat,
1285 zed_version: Option<String>,
1286) -> Example {
1287 let requested_editable_range = input
1288 .excerpt_ranges
1289 .as_ref()
1290 .map(|ranges| excerpt_range_for_format(requested_format, ranges).0)
1291 .unwrap_or_else(|| input.editable_range_in_excerpt.clone());
1292
1293 let base_cursor_excerpt = input.cursor_excerpt.to_string();
1294
1295 let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1296 && requested_editable_range.end <= base_cursor_excerpt.len();
1297 let mut example = build_example_from_snowflake(
1298 request_id.clone(),
1299 device_id,
1300 time,
1301 input,
1302 vec!["settled".to_string()],
1303 None,
1304 zed_version,
1305 );
1306
1307 if !requested_range_is_valid {
1308 log::warn!(
1309 "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1310 request_id,
1311 requested_editable_range,
1312 base_cursor_excerpt.len(),
1313 );
1314 return example;
1315 }
1316
1317 let settled_replacement = settled_editable_region.as_str();
1318 let rejected_patch = build_output_patch(
1319 &example.spec.cursor_path,
1320 &base_cursor_excerpt,
1321 &requested_editable_range,
1322 &requested_output,
1323 );
1324 let expected_patch = build_output_patch(
1325 &example.spec.cursor_path,
1326 &base_cursor_excerpt,
1327 &requested_editable_range,
1328 settled_replacement,
1329 );
1330
1331 example.spec.expected_patches = vec![expected_patch];
1332 example.spec.rejected_patch = Some(rejected_patch);
1333 example
1334}
1335
1336fn rejected_examples_from_response<'a>(
1337 response: &'a SnowflakeStatementResponse,
1338 column_indices: &'a std::collections::HashMap<String, usize>,
1339) -> Result<impl Iterator<Item = Example> + 'a> {
1340 if let Some(code) = &response.code {
1341 if code != SNOWFLAKE_SUCCESS_CODE {
1342 anyhow::bail!(
1343 "snowflake sql api returned error code={code} message={}",
1344 response.message.as_deref().unwrap_or("<no message>")
1345 );
1346 }
1347 }
1348
1349 let iter = response
1350 .data
1351 .iter()
1352 .enumerate()
1353 .filter_map(move |(row_index, data_row)| {
1354 let get_string = |name: &str| -> Option<String> {
1355 let index = column_indices.get(name).copied()?;
1356 match data_row.get(index)? {
1357 JsonValue::String(s) => Some(s.clone()),
1358 JsonValue::Null => None,
1359 other => Some(other.to_string()),
1360 }
1361 };
1362
1363 let get_json = |name: &str| -> Option<JsonValue> {
1364 let index = column_indices.get(name).copied()?;
1365 let value = data_row.get(index)?;
1366 if value.is_null() {
1367 return None;
1368 }
1369 match value {
1370 JsonValue::String(s) => serde_json::from_str(s).ok(),
1371 other => Some(other.clone()),
1372 }
1373 };
1374
1375 let get_bool = |name: &str| -> Option<bool> {
1376 let index = column_indices.get(name).copied()?;
1377 match data_row.get(index)? {
1378 JsonValue::Bool(b) => Some(*b),
1379 JsonValue::String(s) => s.parse().ok(),
1380 _ => None,
1381 }
1382 };
1383
1384 let request_id_str = get_string("request_id");
1385 let device_id = get_string("device_id");
1386 let time = get_string("time");
1387 let input_json = get_json("input");
1388 let input: Option<ZetaPromptInput> =
1389 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1390 let output = get_string("output");
1391 let was_shown = get_bool("was_shown");
1392 let reason = get_string("reason");
1393 let zed_version = get_string("zed_version");
1394
1395 match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1396 (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1397 Some(build_rejected_example(
1398 request_id,
1399 device_id,
1400 time,
1401 input,
1402 output,
1403 was_shown,
1404 reason,
1405 zed_version,
1406 ))
1407 }
1408 _ => {
1409 log::warn!(
1410 "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1411 request_id_str.is_some(),
1412 device_id.is_some(),
1413 time.is_some(),
1414 input_json.is_some(),
1415 output.is_some(),
1416 was_shown.is_some(),
1417 reason.is_some()
1418 );
1419 None
1420 }
1421 }
1422 });
1423
1424 Ok(iter)
1425}
1426
1427fn build_rejected_example(
1428 request_id: String,
1429 device_id: String,
1430 time: String,
1431 input: ZetaPromptInput,
1432 output: String,
1433 was_shown: bool,
1434 reason: String,
1435 zed_version: Option<String>,
1436) -> Example {
1437 let rejected_patch = build_output_patch(
1438 &input.cursor_path,
1439 input.cursor_excerpt.as_ref(),
1440 &input.editable_range_in_excerpt,
1441 &output,
1442 );
1443 let mut example = build_example_from_snowflake(
1444 request_id,
1445 device_id,
1446 time,
1447 input,
1448 vec![format!("rejection:{}", reason.to_lowercase())],
1449 Some(RejectionInfo { reason, was_shown }),
1450 zed_version,
1451 );
1452 example.spec.rejected_patch = Some(rejected_patch);
1453 example
1454}
1455
/// Rejection details copied into an example's telemetry source.
struct RejectionInfo {
    /// Value of the row's `was_shown` column.
    was_shown: bool,
    /// Value of the row's `reason` column.
    reason: String,
}
1460
1461fn build_example_from_snowflake(
1462 request_id: String,
1463 device_id: String,
1464 time: String,
1465 input: ZetaPromptInput,
1466 tags: Vec<String>,
1467 rejection: Option<RejectionInfo>,
1468 zed_version: Option<String>,
1469) -> Example {
1470 let cursor_excerpt = input.cursor_excerpt.as_ref();
1471 let cursor_offset = input.cursor_offset_in_excerpt;
1472
1473 let mut edit_history = String::new();
1474 for event in &input.events {
1475 zeta_prompt::write_event(&mut edit_history, event);
1476 edit_history.push('\n');
1477 }
1478
1479 let (rejection_reason, was_shown) = match &rejection {
1480 Some(r) => (r.reason.clone(), r.was_shown),
1481 None => (String::new(), false),
1482 };
1483
1484 let spec = ExampleSpec {
1485 name: request_id.clone(),
1486 repository_url: String::new(),
1487 revision: String::new(),
1488 tags,
1489 reasoning: None,
1490 uncommitted_diff: String::new(),
1491 cursor_path: input.cursor_path.clone(),
1492 cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1493 edit_history,
1494 expected_patches: Vec::new(),
1495 rejected_patch: None,
1496 telemetry: Some(TelemetrySource {
1497 request_id,
1498 device_id,
1499 time,
1500 rejection_reason,
1501 was_shown,
1502 }),
1503 human_feedback: Vec::new(),
1504 rating: None,
1505 };
1506
1507 Example {
1508 spec,
1509 zed_version,
1510 prompt_inputs: Some(input),
1511 prompt: None,
1512 predictions: Vec::new(),
1513 score: Vec::new(),
1514 qa: Vec::new(),
1515 state: None,
1516 }
1517}
1518
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte offset
/// `cursor_offset`.
///
/// An offset past the end is clamped to the end. An offset that falls inside a
/// multi-byte UTF-8 character is moved back to the start of that character.
/// Offsets come from telemetry, and slicing at a non-boundary would panic.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1524
1525fn build_output_patch(
1526 cursor_path: &std::path::Path,
1527 cursor_excerpt: &str,
1528 editable_range: &std::ops::Range<usize>,
1529 model_output: &str,
1530) -> String {
1531 let old_text = &cursor_excerpt[editable_range.clone()];
1532
1533 let editable_start_row = cursor_excerpt[..editable_range.start]
1534 .chars()
1535 .filter(|&c| c == '\n')
1536 .count() as u32;
1537
1538 let diff_body = language::unified_diff_with_offsets(
1539 old_text,
1540 model_output,
1541 editable_start_row,
1542 editable_start_row,
1543 );
1544
1545 let mut patch = String::new();
1546 writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1547 writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1548 patch.push_str(&diff_body);
1549 patch
1550}
1551
1552pub(crate) fn get_column_indices(
1553 meta: &Option<SnowflakeResultSetMetaData>,
1554 names: &[&str],
1555) -> std::collections::HashMap<String, usize> {
1556 let mut indices = std::collections::HashMap::new();
1557 if let Some(meta) = meta {
1558 for (index, col) in meta.row_type.iter().enumerate() {
1559 for &name in names {
1560 if col.name.eq_ignore_ascii_case(name) {
1561 indices.insert(name.to_string(), index);
1562 }
1563 }
1564 }
1565 }
1566 indices
1567}