@@ -30,6 +30,18 @@ pub struct ExampleSpec {
pub rejected_patch: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub captured_prompt_input: Option<CapturedPromptInput>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub telemetry: Option<TelemetrySource>,
+}
+
+/// Metadata for examples sourced from production telemetry (rejected predictions).
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct TelemetrySource {
+ pub request_id: String,
+ pub device_id: String,
+ pub time: String,
+ pub rejection_reason: String,
+ pub was_shown: bool,
}
/// All data needed to run format_prompt without loading the project.
@@ -239,6 +251,7 @@ impl ExampleSpec {
expected_patches: Vec::new(),
rejected_patch: None,
captured_prompt_input: None,
+ telemetry: None,
};
if let Some(rest) = input.strip_prefix("+++\n")
@@ -486,6 +499,7 @@ mod tests {
expected_patches: Vec::new(),
rejected_patch: None,
captured_prompt_input: None,
+ telemetry: None,
};
// Cursor before `42`
@@ -620,6 +634,7 @@ mod tests {
expected_patches: Vec::new(),
rejected_patch: None,
captured_prompt_input: None,
+ telemetry: None,
};
// Cursor before `42` using inline marker
@@ -105,8 +105,11 @@ Inputs can be file paths or special specifiers:
captured-after:{timestamp}
Fetch captured examples from Snowflake after the given RFC3339 timestamp.
+ These are examples captured via the "Capture Edit Prediction Example" action.
- You can specify this multiple times and mix it with file inputs.
+ rejected-after:{timestamp}
+ Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
+ These are predictions that were shown to users but rejected (useful for DPO training).
Required environment variables to connect to Snowflake:
EP_SNOWFLAKE_API_KEY
@@ -117,20 +120,23 @@ Inputs can be file paths or special specifiers:
Examples:
- # Predict from a file
- ep predict examples.jsonl
+ # Read examples from a file
+ ep read examples.jsonl -o output.jsonl
- # Predict from captured examples after a timestamp
- ep predict captured-after:2025-01-01T00:00:00Z
+ # Read captured examples after a timestamp
+ ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl
- # Mix file inputs and captured-after in the same invocation
+ # Read rejected predictions for DPO training
+ ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl
+
+ # Mix multiple input sources
ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
#[derive(Subcommand, Debug, Clone)]
enum Command {
- /// Parse markdown examples and output a combined .jsonl file
- ParseExample,
+ /// Read examples from files or fetch from Snowflake, output as .jsonl
+ Read,
/// Create git worktrees for each example and load file contents
LoadProject,
/// Retrieve context for input examples.
@@ -168,7 +174,7 @@ enum Command {
impl Display for Command {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
- Command::ParseExample => write!(f, "parse-example"),
+ Command::Read => write!(f, "read"),
Command::LoadProject => write!(f, "load-project"),
Command::Context => write!(f, "context"),
Command::FormatPrompt(args) => {
@@ -357,12 +363,17 @@ async fn load_examples(
background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
let mut captured_after_timestamps = Vec::new();
+ let mut rejected_after_timestamps = Vec::new();
let mut file_inputs = Vec::new();
for input in &args.inputs {
let input_string = input.to_string_lossy();
if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
captured_after_timestamps.push(timestamp.to_string());
+ } else if let Some(timestamp) =
+ pull_examples::parse_rejected_after_input(input_string.as_ref())
+ {
+ rejected_after_timestamps.push(timestamp.to_string());
} else {
file_inputs.push(input.clone());
}
@@ -377,21 +388,36 @@ async fn load_examples(
if let Some(0) = remaining_limit_for_snowflake {
log::info!(
- "skipping captured-after inputs because --limit is already satisfied by example files"
+ "skipping Snowflake inputs because --limit is already satisfied by example files"
);
- } else if !captured_after_timestamps.is_empty() {
- captured_after_timestamps.sort();
-
+ } else {
let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
- let mut captured_examples = pull_examples::fetch_captured_examples_after(
- http_client,
- &captured_after_timestamps,
- max_rows_per_timestamp,
- background_executor,
- )
- .await?;
- examples.append(&mut captured_examples);
+ if !captured_after_timestamps.is_empty() {
+ captured_after_timestamps.sort();
+
+ let mut captured_examples = pull_examples::fetch_captured_examples_after(
+ http_client.clone(),
+ &captured_after_timestamps,
+ max_rows_per_timestamp,
+ background_executor.clone(),
+ )
+ .await?;
+ examples.append(&mut captured_examples);
+ }
+
+ if !rejected_after_timestamps.is_empty() {
+ rejected_after_timestamps.sort();
+
+ let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
+ http_client,
+ &rejected_after_timestamps,
+ max_rows_per_timestamp,
+ background_executor,
+ )
+ .await?;
+ examples.append(&mut rejected_examples);
+ }
}
crate::example::sort_examples_by_repo_and_rev(&mut examples);
@@ -687,7 +713,7 @@ fn main() {
let result = async {
match &command {
- Command::ParseExample => {}
+ Command::Read => {}
Command::LoadProject => {
run_load_project(
example,
@@ -9,15 +9,21 @@ use std::io::Read;
use std::sync::Arc;
use std::time::Duration;
-use crate::{
- example::Example,
- progress::{InfoStyle, Progress, Step},
+use zeta_prompt::ZetaPromptInput;
+
+use crate::example::Example;
+use crate::progress::{InfoStyle, Progress, Step};
+use edit_prediction::example_spec::{
+ CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
+ TelemetrySource,
};
-use edit_prediction::example_spec::ExampleSpec;
+use std::fmt::Write as _;
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
+const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
+const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
const POLL_INTERVAL: Duration = Duration::from_secs(2);
@@ -28,6 +34,11 @@ pub fn parse_captured_after_input(input: &str) -> Option<&str> {
input.strip_prefix("captured-after:")
}
+/// Parse an input token of the form `rejected-after:{timestamp}`.
+pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
+ input.strip_prefix("rejected-after:")
+}
+
pub async fn fetch_captured_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
@@ -302,6 +313,7 @@ async fn fetch_partition(
)
.header("Accept", "application/json")
.header("Accept-Encoding", "gzip")
+ .header("User-Agent", "edit_prediction_cli")
.body(AsyncBody::empty())?;
let response = http_client
@@ -387,6 +399,7 @@ async fn run_sql(
)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
+ .header("User-Agent", "edit_prediction_cli")
.body(AsyncBody::from(request_body.clone()))?;
let response = http_client
@@ -414,3 +427,401 @@ async fn run_sql(
serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
.context("failed to parse Snowflake SQL API response JSON")
}
+
+pub async fn fetch_rejected_examples_after(
+ http_client: Arc<dyn HttpClient>,
+ after_timestamps: &[String],
+ max_rows_per_timestamp: usize,
+ background_executor: BackgroundExecutor,
+) -> Result<Vec<Example>> {
+ if after_timestamps.is_empty() {
+ return Ok(Vec::new());
+ }
+
+ let progress = Progress::global();
+
+ let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+ .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+ let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+ "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+ )?;
+ let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+ let mut all_examples = Vec::new();
+
+ for after_date in after_timestamps.iter() {
+ let step_progress_name = format!("rejected>{after_date}");
+ let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+ step_progress.set_substatus("querying");
+
+ // Join rejected events with their corresponding request events to get the full context.
+ // We filter for V3 sampling data which contains the structured input we need.
+ // We also filter for predictions that were actually shown to the user (was_shown = true)
+ // to focus on explicit user rejections rather than implicit cancellations.
+ let statement = indoc! {r#"
+ SELECT
+ req.event_properties:request_id::string AS request_id,
+ req.device_id::string AS device_id,
+ req.time::string AS time,
+ req.event_properties:input AS input,
+ req.event_properties:prompt::string AS prompt,
+ req.event_properties:output::string AS output,
+ rej.event_properties:was_shown::boolean AS was_shown,
+ rej.event_properties:reason::string AS reason
+ FROM events req
+ INNER JOIN events rej
+ ON req.event_properties:request_id = rej.event_properties:request_id
+ WHERE req.event_type = ?
+ AND rej.event_type = ?
+ AND req.event_properties:version = 'V3'
+ AND rej.event_properties:was_shown = true
+ AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+ ORDER BY req.time ASC
+ LIMIT ?
+ "#};
+
+ let request = json!({
+ "statement": statement,
+ "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ "database": "EVENTS",
+ "schema": "PUBLIC",
+ "warehouse": "DBT",
+ "role": role,
+ "bindings": {
+ "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+ "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
+ "3": { "type": "TEXT", "value": after_date },
+ "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+ }
+ });
+
+ let response = run_sql_with_polling(
+ http_client.clone(),
+ &base_url,
+ &token,
+ &request,
+ &step_progress,
+ background_executor.clone(),
+ )
+ .await?;
+
+ let total_rows = response
+ .result_set_meta_data
+ .as_ref()
+ .and_then(|m| m.num_rows)
+ .unwrap_or(response.data.len() as i64);
+
+ let num_partitions = response
+ .result_set_meta_data
+ .as_ref()
+ .map(|m| m.partition_info.len())
+ .unwrap_or(1)
+ .max(1);
+
+ step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+ step_progress.set_substatus("parsing");
+
+ all_examples.extend(rejected_examples_from_response(&response)?);
+
+ if num_partitions > 1 {
+ let statement_handle = response
+ .statement_handle
+ .as_ref()
+ .context("response has multiple partitions but no statementHandle")?;
+
+ for partition in 1..num_partitions {
+ step_progress.set_substatus(format!(
+ "fetching partition {}/{}",
+ partition + 1,
+ num_partitions
+ ));
+
+ let partition_response = fetch_partition(
+ http_client.clone(),
+ &base_url,
+ &token,
+ statement_handle,
+ partition,
+ )
+ .await?;
+
+ all_examples.extend(rejected_examples_from_response(&partition_response)?);
+ }
+ }
+
+ step_progress.set_substatus("done");
+ }
+
+ Ok(all_examples)
+}
+
+fn rejected_examples_from_response(
+ response: &SnowflakeStatementResponse,
+) -> Result<impl Iterator<Item = Example> + '_> {
+ if let Some(code) = &response.code {
+ if code != SNOWFLAKE_SUCCESS_CODE {
+ anyhow::bail!(
+ "snowflake sql api returned error code={code} message={}",
+ response.message.as_deref().unwrap_or("<no message>")
+ );
+ }
+ }
+
+ let column_indices = get_column_indices(
+ &response.result_set_meta_data,
+ &[
+ "request_id",
+ "device_id",
+ "time",
+ "input",
+ "prompt",
+ "output",
+ "was_shown",
+ "reason",
+ ],
+ );
+
+ let iter = response
+ .data
+ .iter()
+ .enumerate()
+ .filter_map(move |(row_index, data_row)| {
+ let get_string = |name: &str| -> Option<String> {
+ let index = column_indices.get(name).copied()?;
+ match data_row.get(index)? {
+ JsonValue::String(s) => Some(s.clone()),
+ JsonValue::Null => None,
+ other => Some(other.to_string()),
+ }
+ };
+
+ let get_json = |name: &str| -> Option<JsonValue> {
+ let index = column_indices.get(name).copied()?;
+ let value = data_row.get(index)?;
+ if value.is_null() {
+ return None;
+ }
+ match value {
+ JsonValue::String(s) => serde_json::from_str(s).ok(),
+ other => Some(other.clone()),
+ }
+ };
+
+ let get_bool = |name: &str| -> Option<bool> {
+ let index = column_indices.get(name).copied()?;
+ match data_row.get(index)? {
+ JsonValue::Bool(b) => Some(*b),
+ JsonValue::String(s) => s.parse().ok(),
+ _ => None,
+ }
+ };
+
+ let request_id_str = get_string("request_id");
+ let device_id = get_string("device_id");
+ let time = get_string("time");
+ let input_json = get_json("input");
+ let input: Option<ZetaPromptInput> =
+ input_json.clone().and_then(|v| serde_json::from_value(v).ok());
+ let output = get_string("output");
+ let was_shown = get_bool("was_shown");
+ let reason = get_string("reason");
+
+ match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
+ (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
+ Some(build_rejected_example(
+ request_id,
+ device_id,
+ time,
+ input,
+ output,
+ was_shown,
+ reason,
+ ))
+ }
+ _ => {
+ log::warn!(
+ "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
+ request_id_str.is_some(),
+ device_id.is_some(),
+ time.is_some(),
+ input_json.is_some(),
+ output.is_some(),
+ was_shown.is_some(),
+ reason.is_some()
+ );
+ None
+ }
+ }
+ });
+
+ Ok(iter)
+}
+
+fn build_rejected_example(
+ request_id: String,
+ device_id: String,
+ time: String,
+ input: ZetaPromptInput,
+ output: String,
+ was_shown: bool,
+ reason: String,
+) -> Example {
+ let events: Vec<CapturedEvent> = input
+ .events
+ .iter()
+ .map(|event| match event.as_ref() {
+ zeta_prompt::Event::BufferChange {
+ path,
+ old_path,
+ diff,
+ predicted,
+ in_open_source_repo,
+ } => CapturedEvent {
+ path: path.clone(),
+ old_path: old_path.clone(),
+ diff: diff.clone(),
+ predicted: *predicted,
+ in_open_source_repo: *in_open_source_repo,
+ },
+ })
+ .collect();
+
+ let related_files: Vec<CapturedRelatedFile> = input
+ .related_files
+ .iter()
+ .map(|rf| CapturedRelatedFile {
+ path: rf.path.clone(),
+ max_row: rf.max_row,
+ excerpts: rf
+ .excerpts
+ .iter()
+ .map(|e| CapturedRelatedExcerpt {
+ row_range: e.row_range.clone(),
+ text: e.text.to_string(),
+ })
+ .collect(),
+ })
+ .collect();
+
+ let cursor_excerpt = input.cursor_excerpt.as_ref();
+ let cursor_offset = input.cursor_offset_in_excerpt;
+
+ let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
+
+ let mut edit_history = String::new();
+ for event in &input.events {
+ zeta_prompt::write_event(&mut edit_history, event);
+ edit_history.push('\n');
+ }
+
+ let rejected_patch = build_rejected_patch(
+ &input.cursor_path,
+ cursor_excerpt,
+ &input.editable_range_in_excerpt,
+ &output,
+ );
+
+ let spec = ExampleSpec {
+ name: request_id.clone(),
+ repository_url: String::new(),
+ revision: String::new(),
+ tags: vec![format!("rejection:{}", reason.to_lowercase())],
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: input.cursor_path.clone(),
+ cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
+ edit_history,
+ expected_patches: Vec::new(),
+ rejected_patch: Some(rejected_patch),
+ captured_prompt_input: Some(CapturedPromptInput {
+ cursor_file_content: cursor_excerpt.to_string(),
+ cursor_offset,
+ cursor_row,
+ cursor_column,
+ events,
+ related_files,
+ }),
+ telemetry: Some(TelemetrySource {
+ request_id,
+ device_id,
+ time,
+ rejection_reason: reason,
+ was_shown,
+ }),
+ };
+
+ Example {
+ spec,
+ prompt_inputs: None,
+ prompt: None,
+ predictions: Vec::new(),
+ score: Vec::new(),
+ state: None,
+ }
+}
+
+fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
+ let mut row = 0u32;
+ let mut last_newline_offset = 0;
+ for (i, c) in text.char_indices() {
+ if i >= offset {
+ break;
+ }
+ if c == '\n' {
+ row += 1;
+ last_newline_offset = i + 1;
+ }
+ }
+ let column = (offset - last_newline_offset) as u32;
+ (row, column)
+}
+
+fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
+ let before = &excerpt[..cursor_offset.min(excerpt.len())];
+ let after = &excerpt[cursor_offset.min(excerpt.len())..];
+ format!("{}[CURSOR_POSITION]{}", before, after)
+}
+
+fn build_rejected_patch(
+ cursor_path: &std::path::Path,
+ cursor_excerpt: &str,
+ editable_range: &std::ops::Range<usize>,
+ model_output: &str,
+) -> String {
+ let old_text = &cursor_excerpt[editable_range.clone()];
+
+ let editable_start_row = cursor_excerpt[..editable_range.start]
+ .chars()
+ .filter(|&c| c == '\n')
+ .count() as u32;
+
+ let diff_body = language::unified_diff_with_offsets(
+ old_text,
+ model_output,
+ editable_start_row,
+ editable_start_row,
+ );
+
+ let mut patch = String::new();
+ writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
+ writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
+ patch.push_str(&diff_body);
+ patch
+}
+
+fn get_column_indices(
+ meta: &Option<SnowflakeResultSetMetaData>,
+ names: &[&str],
+) -> std::collections::HashMap<String, usize> {
+ let mut indices = std::collections::HashMap::new();
+ if let Some(meta) = meta {
+ for (index, col) in meta.row_type.iter().enumerate() {
+ for &name in names {
+ if col.name.eq_ignore_ascii_case(name) {
+ indices.insert(name.to_string(), index);
+ }
+ }
+ }
+ }
+ indices
+}