ep cli: Load captured examples from Snowflake (#46102)

Agus Zubiaga created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/example.rs       |  11 
crates/edit_prediction_cli/src/main.rs          | 114 ++++++++-
crates/edit_prediction_cli/src/progress.rs      |  63 ++++
crates/edit_prediction_cli/src/pull_examples.rs | 220 +++++++++++++++++++
crates/edit_prediction_cli/src/split_commit.rs  |   2 
crates/edit_prediction_cli/src/synthesize.rs    |   2 
6 files changed, 376 insertions(+), 36 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/example.rs 🔗

@@ -125,17 +125,9 @@ impl Example {
     }
 }
 
-pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
+pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
     let mut examples = Vec::new();
 
-    let stdin_path: PathBuf = PathBuf::from("-");
-
-    let inputs = if inputs.is_empty() {
-        &[stdin_path]
-    } else {
-        inputs
-    };
-
     for path in inputs {
         let is_stdin = path.as_path() == Path::new("-");
         let content = if is_stdin {
@@ -201,7 +193,6 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
         }
     }
 
-    sort_examples_by_repo_and_rev(&mut examples);
     examples
 }
 

crates/edit_prediction_cli/src/main.rs 🔗

@@ -9,12 +9,12 @@ mod metrics;
 mod paths;
 mod predict;
 mod progress;
+mod pull_examples;
 mod reorder_patch;
 mod retrieve_context;
 mod score;
 mod split_commit;
 mod synthesize;
-
 use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 use edit_prediction::EditPredictionStore;
 use gpui::Application;
@@ -24,7 +24,7 @@ use std::fmt::Display;
 use std::{path::PathBuf, sync::Arc};
 
 use crate::distill::run_distill;
-use crate::example::{group_examples_by_repo, read_examples, write_examples};
+use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
 use crate::format_prompt::run_format_prompt;
 use crate::load_project::run_load_project;
 use crate::paths::FAILED_EXAMPLES_DIR;
@@ -42,9 +42,11 @@ struct EpArgs {
     printenv: bool,
     #[clap(long, default_value_t = 10, global = true)]
     max_parallelism: usize,
+    #[clap(long, global = true)]
+    limit: Option<usize>,
     #[command(subcommand)]
     command: Option<Command>,
-    #[clap(global = true)]
+    #[clap(global = true, help = INPUTS_HELP)]
     inputs: Vec<PathBuf>,
     #[arg(long, short, global = true)]
     output: Option<PathBuf>,
@@ -54,7 +56,37 @@ struct EpArgs {
     failfast: bool,
 }
 
-#[derive(Subcommand, Debug)]
+const INPUTS_HELP: &str = r#"
+Inputs can be file paths or special specifiers:
+
+  path
+      Path to an example(s) file (.md, .json, or .jsonl)
+
+  captured-after:{timestamp}
+      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
+
+      You can specify this multiple times and mix it with file inputs.
+
+      Required environment variables to connect to Snowflake:
+          EP_SNOWFLAKE_API_KEY
+          EP_SNOWFLAKE_BASE_URL
+
+      Optional:
+          EP_SNOWFLAKE_ROLE
+
+Examples:
+
+  # Predict from a file
+  ep predict examples.jsonl
+
+  # Predict from captured examples after a timestamp
+  ep predict captured-after:2025-01-01T00:00:00Z
+
+  # Mix file inputs and captured-after in the same invocation
+  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
+"#;
+
+#[derive(Subcommand, Debug, Clone)]
 enum Command {
     /// Parse markdown examples and output a combined .jsonl file
     ParseExample,
@@ -137,7 +169,7 @@ impl Display for Command {
     }
 }
 
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
 struct FormatPromptArgs {
     #[clap(long)]
     prompt_format: PromptFormat,
@@ -149,7 +181,7 @@ enum PromptFormat {
     Zeta2,
 }
 
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
 struct PredictArgs {
     #[clap(long)]
     provider: PredictionProvider,
@@ -167,7 +199,7 @@ enum PredictionProvider {
     TeacherNonBatching,
 }
 
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
 struct SynthesizeArgs {
     /// Repository URL (git@github.com:owner/repo or https://...)
     #[clap(long)]
@@ -200,6 +232,60 @@ impl EpArgs {
     }
 }
 
+async fn load_examples(
+    http_client: Arc<dyn http_client::HttpClient>,
+    args: &EpArgs,
+) -> anyhow::Result<Vec<Example>> {
+    let mut captured_after_timestamps = Vec::new();
+    let mut file_inputs = Vec::new();
+
+    for input in &args.inputs {
+        let input_string = input.to_string_lossy();
+        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
+            captured_after_timestamps.push(timestamp.to_string());
+        } else {
+            file_inputs.push(input.clone());
+        }
+    }
+
+    let mut examples = read_example_files(&file_inputs);
+    let total_steps = examples.len() + captured_after_timestamps.len();
+    Progress::global().set_total_steps(total_steps);
+
+    let remaining_limit_for_snowflake =
+        args.limit.map(|limit| limit.saturating_sub(examples.len()));
+
+    if let Some(0) = remaining_limit_for_snowflake {
+        log::info!(
+            "skipping captured-after inputs because --limit is already satisfied by example files"
+        );
+    } else if !captured_after_timestamps.is_empty() {
+        captured_after_timestamps.sort();
+
+        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
+
+        let mut captured_examples = pull_examples::fetch_captured_examples_after(
+            http_client,
+            &captured_after_timestamps,
+            max_rows_per_timestamp,
+        )
+        .await?;
+        examples.append(&mut captured_examples);
+    }
+
+    crate::example::sort_examples_by_repo_and_rev(&mut examples);
+
+    if let Some(limit) = args.limit {
+        if examples.len() > limit {
+            examples.truncate(limit);
+        }
+    }
+
+    Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
+
+    Ok(examples)
+}
+
 fn main() {
     let args = EpArgs::parse();
 
@@ -209,8 +295,8 @@ fn main() {
     }
 
     let output = args.output_path();
-    let command = match args.command {
-        Some(cmd) => cmd,
+    let command = match &args.command {
+        Some(cmd) => cmd.clone(),
         None => {
             EpArgs::command().print_help().unwrap();
             return;
@@ -251,7 +337,6 @@ fn main() {
         _ => {}
     }
 
-    let mut examples = read_examples(&args.inputs);
     let http_client = Arc::new(ReqwestClient::new());
     let app = Application::headless().with_http_client(http_client);
 
@@ -261,12 +346,13 @@ fn main() {
 
         cx.spawn(async move |cx| {
             let result = async {
+                let mut examples = load_examples(app_state.client.http_client(), &args).await?;
+
                 if let Command::Predict(args) = &command {
                     predict::sync_batches(&args.provider).await?;
                 }
 
-                let total_examples = examples.len();
-                Progress::global().set_total_examples(total_examples);
+                let failfast_on_single_example = examples.len() == 1;
 
                 let mut grouped_examples = group_examples_by_repo(&mut examples);
                 let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
@@ -347,7 +433,7 @@ fn main() {
 
                                 let msg = format!(
                                     indoc::indoc! {"
-                                        While processing {}:
+                                        While processing \"{}\":
 
                                         {:?}
 
@@ -366,7 +452,7 @@ fn main() {
                                     command,
                                     failed_example_path.display(),
                                 );
-                                if args.failfast || total_examples == 1 {
+                                if args.failfast || failfast_on_single_example {
                                     Progress::global().finalize();
                                     panic!("{}", msg);
                                 } else {

crates/edit_prediction_cli/src/progress.rs 🔗

@@ -19,9 +19,10 @@ struct ProgressInner {
     terminal_width: usize,
     max_example_name_len: usize,
     status_lines_displayed: usize,
-    total_examples: usize,
+    total_steps: usize,
     failed_examples: usize,
     last_line_is_logging: bool,
+    ticker: Option<std::thread::JoinHandle<()>>,
 }
 
 #[derive(Clone)]
@@ -47,6 +48,7 @@ pub enum Step {
     Predict,
     Score,
     Synthesize,
+    PullExamples,
 }
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -64,6 +66,7 @@ impl Step {
             Step::Predict => "Predict",
             Step::Score => "Score",
             Step::Synthesize => "Synthesize",
+            Step::PullExamples => "Pull",
         }
     }
 
@@ -75,6 +78,7 @@ impl Step {
             Step::Predict => "\x1b[32m",
             Step::Score => "\x1b[31m",
             Step::Synthesize => "\x1b[36m",
+            Step::PullExamples => "\x1b[36m",
         }
     }
 }
@@ -84,6 +88,7 @@ static LOGGER: ProgressLogger = ProgressLogger;
 
 const MARGIN: usize = 4;
 const MAX_STATUS_LINES: usize = 10;
+const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 
 impl Progress {
     /// Returns the global Progress instance, initializing it if necessary.
@@ -98,9 +103,10 @@ impl Progress {
                         terminal_width: get_terminal_width(),
                         max_example_name_len: 0,
                         status_lines_displayed: 0,
-                        total_examples: 0,
+                        total_steps: 0,
                         failed_examples: 0,
                         last_line_is_logging: false,
+                        ticker: None,
                     }),
                 });
                 let _ = log::set_logger(&LOGGER);
@@ -110,9 +116,9 @@ impl Progress {
             .clone()
     }
 
-    pub fn set_total_examples(&self, total: usize) {
+    pub fn set_total_steps(&self, total: usize) {
         let mut inner = self.inner.lock().unwrap();
-        inner.total_examples = total;
+        inner.total_steps = total;
     }
 
     pub fn increment_failed(&self) {
@@ -142,7 +148,14 @@ impl Progress {
 
         Self::clear_status_lines(&mut inner);
 
-        inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
+        let max_name_width = inner
+            .terminal_width
+            .saturating_sub(MARGIN * 2)
+            .saturating_div(3)
+            .max(1);
+        inner.max_example_name_len = inner
+            .max_example_name_len
+            .max(example_name.len().min(max_name_width));
         inner.in_progress.insert(
             example_name.to_string(),
             InProgressTask {
@@ -153,6 +166,23 @@ impl Progress {
             },
         );
 
+        if inner.is_tty && inner.ticker.is_none() {
+            let progress = self.clone();
+            inner.ticker = Some(std::thread::spawn(move || {
+                loop {
+                    std::thread::sleep(STATUS_TICK_INTERVAL);
+
+                    let mut inner = progress.inner.lock().unwrap();
+                    if inner.in_progress.is_empty() {
+                        break;
+                    }
+
+                    Progress::clear_status_lines(&mut inner);
+                    Progress::print_status_lines(&mut inner);
+                }
+            }));
+        }
+
         Self::print_status_lines(&mut inner);
 
         StepProgress {
@@ -179,7 +209,9 @@ impl Progress {
 
             Self::clear_status_lines(&mut inner);
             Self::print_logging_closing_divider(&mut inner);
-            Self::print_completed(&inner, inner.completed.last().unwrap());
+            if let Some(last_completed) = inner.completed.last() {
+                Self::print_completed(&inner, last_completed);
+            }
             Self::print_status_lines(&mut inner);
         } else {
             inner.in_progress.insert(example_name.to_string(), task);
@@ -210,6 +242,7 @@ impl Progress {
     fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
         let duration = format_duration(task.duration);
         let name_width = inner.max_example_name_len;
+        let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
 
         if inner.is_tty {
             let reset = "\x1b[0m";
@@ -233,7 +266,7 @@ impl Progress {
                 "{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
                 color = task.step.color_code(),
                 label = task.step.label(),
-                name = task.example_name,
+                name = truncated_name,
             );
 
             let duration_with_margin = format!("{duration} ");
@@ -255,7 +288,7 @@ impl Progress {
             eprintln!(
                 "{label:>12} {name:<name_width$}{info_part} {duration}",
                 label = task.step.label(),
-                name = task.example_name,
+                name = truncate_with_ellipsis(&task.example_name, name_width),
             );
         }
     }
@@ -283,7 +316,7 @@ impl Progress {
 
         let range_label = format!(
             " {}/{}/{} ",
-            done_count, in_progress_count, inner.total_examples
+            done_count, in_progress_count, inner.total_steps
         );
 
         // Print a divider line with failed count on left, range label on right
@@ -318,10 +351,11 @@ impl Progress {
             let step_label = task.step.label();
             let step_color = task.step.color_code();
             let name_width = inner.max_example_name_len;
+            let truncated_name = truncate_with_ellipsis(name, name_width);
 
             let prefix = format!(
                 "{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
-                name = name,
+                name = truncated_name,
             );
 
             let duration_with_margin = format!("{elapsed} ");
@@ -348,6 +382,15 @@ impl Progress {
     }
 
     pub fn finalize(&self) {
+        let ticker = {
+            let mut inner = self.inner.lock().unwrap();
+            inner.ticker.take()
+        };
+
+        if let Some(ticker) = ticker {
+            let _ = ticker.join();
+        }
+
         let mut inner = self.inner.lock().unwrap();
         Self::clear_status_lines(&mut inner);
 

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -0,0 +1,220 @@
+use anyhow::{Context as _, Result};
+use http_client::{AsyncBody, HttpClient, Method, Request};
+use indoc::indoc;
+use serde::Deserialize;
+use serde_json::{Value as JsonValue, json};
+use std::sync::Arc;
+
+use crate::{
+    example::Example,
+    progress::{InfoStyle, Progress, Step},
+};
+use edit_prediction::example_spec::ExampleSpec;
+
+const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
+const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
+
+const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+
+/// Parse an input token of the form `captured-after:{timestamp}`.
+pub fn parse_captured_after_input(input: &str) -> Option<&str> {
+    input.strip_prefix("captured-after:")
+}
+
+pub async fn fetch_captured_examples_after(
+    http_client: Arc<dyn HttpClient>,
+    after_timestamps: &[String],
+    max_rows_per_timestamp: usize,
+) -> Result<Vec<Example>> {
+    if after_timestamps.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let progress = Progress::global();
+
+    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+    )?;
+    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+    let mut all_examples = Vec::new();
+
+    for after_date in after_timestamps.iter() {
+        let step_progress_name = format!(">{after_date}");
+        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+        step_progress.set_substatus("querying");
+
+        let statement = indoc! {r#"
+            SELECT
+                event_properties:example AS example
+            FROM events
+            WHERE event_type = ?
+                AND time > TRY_TO_TIMESTAMP_NTZ(?)
+            ORDER BY time ASC
+            LIMIT ?
+        "#};
+
+        let request = json!({
+            "statement": statement,
+            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            "database": "EVENTS",
+            "schema": "PUBLIC",
+            "warehouse": "DBT",
+            "role": role,
+            "bindings": {
+                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
+                "2": { "type": "TEXT", "value": after_date },
+                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+            }
+        });
+
+        let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
+
+        step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
+        step_progress.set_substatus("parsing");
+
+        all_examples.extend(examples_from_response(&response)?);
+
+        step_progress.set_substatus("done");
+    }
+
+    Ok(all_examples)
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeStatementResponse {
+    #[serde(default)]
+    data: Vec<Vec<JsonValue>>,
+    #[serde(default)]
+    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
+    #[serde(default)]
+    code: Option<String>,
+    #[serde(default)]
+    message: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeResultSetMetaData {
+    #[serde(default, rename = "rowType")]
+    row_type: Vec<SnowflakeColumnMeta>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeColumnMeta {
+    #[serde(default)]
+    name: String,
+}
+
+fn examples_from_response(
+    response: &SnowflakeStatementResponse,
+) -> Result<impl Iterator<Item = Example>> {
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake sql api returned error code={code} message={}",
+                response.message.as_deref().unwrap_or("<no message>")
+            );
+        }
+    }
+
+    let example_index = response
+        .result_set_meta_data
+        .as_ref()
+        .and_then(|m| {
+            m.row_type.iter().enumerate().find_map(|(index, col)| {
+                if col.name.eq_ignore_ascii_case("example") {
+                    Some(index)
+                } else {
+                    None
+                }
+            })
+        })
+        .unwrap_or(0);
+
+    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
+        let Some(example_value) = data_row.get(example_index) else {
+            return None;
+        };
+        if example_value.is_null() {
+            return None;
+        }
+
+        let parse_result = match example_value {
+            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
+            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
+        };
+
+        match parse_result {
+            Ok(spec) => Some(Example {
+                spec,
+                buffer: None,
+                context: None,
+                prompt: None,
+                predictions: Vec::new(),
+                score: Vec::new(),
+                state: None,
+            }),
+            Err(error) => {
+                let raw_json = serde_json::to_string_pretty(example_value)
+                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
+                log::error!(
+                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
+                );
+                None
+            }
+        }
+    });
+
+    Ok(iter)
+}
+
+async fn run_sql(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    request: &serde_json::Value,
+) -> Result<SnowflakeStatementResponse> {
+    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
+
+    let request_body =
+        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
+
+    let http_request = Request::builder()
+        .method(Method::POST)
+        .uri(url.as_str())
+        .header("Authorization", format!("Bearer {token}"))
+        .header(
+            "X-Snowflake-Authorization-Token-Type",
+            "PROGRAMMATIC_ACCESS_TOKEN",
+        )
+        .header("Content-Type", "application/json")
+        .header("Accept", "application/json")
+        .body(AsyncBody::from(request_body.clone()))?;
+
+    let response = http_client
+        .send(http_request)
+        .await
+        .context("failed to send request to Snowflake SQL API")?;
+
+    let status = response.status();
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read Snowflake SQL API response body")?;
+        bytes
+    };
+
+    if !status.is_success() {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
+    }
+
+    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
+        .context("failed to parse Snowflake SQL API response JSON")
+}

crates/edit_prediction_cli/src/split_commit.rs 🔗

@@ -16,7 +16,7 @@ use std::fs;
 use std::io::{self, Read};
 
 /// `ep split-commit` CLI args.
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
 pub struct SplitCommitArgs {
     /// Path to the commit file (use "-" for stdin)
     #[arg(long, short = 'c')]

crates/edit_prediction_cli/src/synthesize.rs 🔗

@@ -108,7 +108,7 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
     std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
 
     let progress = Progress::global();
-    progress.set_total_examples(config.count);
+    progress.set_total_steps(config.count);
 
     let clone_progress = progress.start(Step::Synthesize, "clone");
     let repo_path = ensure_repo_cloned(&config.repo_url).await?;