Detailed changes

example module (crate::example):
@@ -125,17 +125,9 @@ impl Example {
}
}
-pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
+pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
let mut examples = Vec::new();
- let stdin_path: PathBuf = PathBuf::from("-");
-
- let inputs = if inputs.is_empty() {
- &[stdin_path]
- } else {
- inputs
- };
-
for path in inputs {
let is_stdin = path.as_path() == Path::new("-");
let content = if is_stdin {
@@ -201,7 +193,6 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
}
}
- sort_examples_by_repo_and_rev(&mut examples);
examples
}

CLI entry point (main module: argument parsing and command dispatch):
@@ -9,12 +9,12 @@ mod metrics;
mod paths;
mod predict;
mod progress;
+mod pull_examples;
mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
mod synthesize;
-
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
@@ -24,7 +24,7 @@ use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
-use crate::example::{group_examples_by_repo, read_examples, write_examples};
+use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
@@ -42,9 +42,11 @@ struct EpArgs {
printenv: bool,
#[clap(long, default_value_t = 10, global = true)]
max_parallelism: usize,
+ #[clap(long, global = true)]
+ limit: Option<usize>,
#[command(subcommand)]
command: Option<Command>,
- #[clap(global = true)]
+ #[clap(global = true, help = INPUTS_HELP)]
inputs: Vec<PathBuf>,
#[arg(long, short, global = true)]
output: Option<PathBuf>,
@@ -54,7 +56,37 @@ struct EpArgs {
failfast: bool,
}
-#[derive(Subcommand, Debug)]
+const INPUTS_HELP: &str = r#"
+Inputs can be file paths or special specifiers:
+
+ path
+ Path to an example(s) file (.md, .json, or .jsonl)
+
+ captured-after:{timestamp}
+ Fetch captured examples from Snowflake after the given RFC3339 timestamp.
+
+ You can specify this multiple times and mix it with file inputs.
+
+ Required environment variables to connect to Snowflake:
+ EP_SNOWFLAKE_API_KEY
+ EP_SNOWFLAKE_BASE_URL
+
+ Optional:
+ EP_SNOWFLAKE_ROLE
+
+Examples:
+
+ # Predict from a file
+ ep predict examples.jsonl
+
+ # Predict from captured examples after a timestamp
+ ep predict captured-after:2025-01-01T00:00:00Z
+
+ # Mix file inputs and captured-after in the same invocation
+ ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
+"#;
+
+#[derive(Subcommand, Debug, Clone)]
enum Command {
/// Parse markdown examples and output a combined .jsonl file
ParseExample,
@@ -137,7 +169,7 @@ impl Display for Command {
}
}
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
#[clap(long)]
prompt_format: PromptFormat,
@@ -149,7 +181,7 @@ enum PromptFormat {
Zeta2,
}
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
struct PredictArgs {
#[clap(long)]
provider: PredictionProvider,
@@ -167,7 +199,7 @@ enum PredictionProvider {
TeacherNonBatching,
}
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
/// Repository URL (git@github.com:owner/repo or https://...)
#[clap(long)]
@@ -200,6 +232,60 @@ impl EpArgs {
}
}
+async fn load_examples(
+ http_client: Arc<dyn http_client::HttpClient>,
+ args: &EpArgs,
+) -> anyhow::Result<Vec<Example>> {
+ let mut captured_after_timestamps = Vec::new();
+ let mut file_inputs = Vec::new();
+
+ for input in &args.inputs {
+ let input_string = input.to_string_lossy();
+ if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
+ captured_after_timestamps.push(timestamp.to_string());
+ } else {
+ file_inputs.push(input.clone());
+ }
+ }
+
+ let mut examples = read_example_files(&file_inputs);
+ let total_steps = examples.len() + captured_after_timestamps.len();
+ Progress::global().set_total_steps(total_steps);
+
+ let remaining_limit_for_snowflake =
+ args.limit.map(|limit| limit.saturating_sub(examples.len()));
+
+ if let Some(0) = remaining_limit_for_snowflake {
+ log::info!(
+ "skipping captured-after inputs because --limit is already satisfied by example files"
+ );
+ } else if !captured_after_timestamps.is_empty() {
+ captured_after_timestamps.sort();
+
+ let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
+
+ let mut captured_examples = pull_examples::fetch_captured_examples_after(
+ http_client,
+ &captured_after_timestamps,
+ max_rows_per_timestamp,
+ )
+ .await?;
+ examples.append(&mut captured_examples);
+ }
+
+ crate::example::sort_examples_by_repo_and_rev(&mut examples);
+
+ if let Some(limit) = args.limit {
+ if examples.len() > limit {
+ examples.truncate(limit);
+ }
+ }
+
+ Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
+
+ Ok(examples)
+}
+
fn main() {
let args = EpArgs::parse();
@@ -209,8 +295,8 @@ fn main() {
}
let output = args.output_path();
- let command = match args.command {
- Some(cmd) => cmd,
+ let command = match &args.command {
+ Some(cmd) => cmd.clone(),
None => {
EpArgs::command().print_help().unwrap();
return;
@@ -251,7 +337,6 @@ fn main() {
_ => {}
}
- let mut examples = read_examples(&args.inputs);
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
@@ -261,12 +346,13 @@ fn main() {
cx.spawn(async move |cx| {
let result = async {
+ let mut examples = load_examples(app_state.client.http_client(), &args).await?;
+
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await?;
}
- let total_examples = examples.len();
- Progress::global().set_total_examples(total_examples);
+ let failfast_on_single_example = examples.len() == 1;
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
@@ -347,7 +433,7 @@ fn main() {
let msg = format!(
indoc::indoc! {"
- While processing {}:
+ While processing \"{}\":
{:?}
@@ -366,7 +452,7 @@ fn main() {
command,
failed_example_path.display(),
);
- if args.failfast || total_examples == 1 {
+ if args.failfast || failfast_on_single_example {
Progress::global().finalize();
panic!("{}", msg);
} else {

progress module:
@@ -19,9 +19,10 @@ struct ProgressInner {
terminal_width: usize,
max_example_name_len: usize,
status_lines_displayed: usize,
- total_examples: usize,
+ total_steps: usize,
failed_examples: usize,
last_line_is_logging: bool,
+ ticker: Option<std::thread::JoinHandle<()>>,
}
#[derive(Clone)]
@@ -47,6 +48,7 @@ pub enum Step {
Predict,
Score,
Synthesize,
+ PullExamples,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -64,6 +66,7 @@ impl Step {
Step::Predict => "Predict",
Step::Score => "Score",
Step::Synthesize => "Synthesize",
+ Step::PullExamples => "Pull",
}
}
@@ -75,6 +78,7 @@ impl Step {
Step::Predict => "\x1b[32m",
Step::Score => "\x1b[31m",
Step::Synthesize => "\x1b[36m",
+ Step::PullExamples => "\x1b[36m",
}
}
}
@@ -84,6 +88,7 @@ static LOGGER: ProgressLogger = ProgressLogger;
const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;
+const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
impl Progress {
/// Returns the global Progress instance, initializing it if necessary.
@@ -98,9 +103,10 @@ impl Progress {
terminal_width: get_terminal_width(),
max_example_name_len: 0,
status_lines_displayed: 0,
- total_examples: 0,
+ total_steps: 0,
failed_examples: 0,
last_line_is_logging: false,
+ ticker: None,
}),
});
let _ = log::set_logger(&LOGGER);
@@ -110,9 +116,9 @@ impl Progress {
.clone()
}
- pub fn set_total_examples(&self, total: usize) {
+ pub fn set_total_steps(&self, total: usize) {
let mut inner = self.inner.lock().unwrap();
- inner.total_examples = total;
+ inner.total_steps = total;
}
pub fn increment_failed(&self) {
@@ -142,7 +148,14 @@ impl Progress {
Self::clear_status_lines(&mut inner);
- inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
+ let max_name_width = inner
+ .terminal_width
+ .saturating_sub(MARGIN * 2)
+ .saturating_div(3)
+ .max(1);
+ inner.max_example_name_len = inner
+ .max_example_name_len
+ .max(example_name.len().min(max_name_width));
inner.in_progress.insert(
example_name.to_string(),
InProgressTask {
@@ -153,6 +166,23 @@ impl Progress {
},
);
+ if inner.is_tty && inner.ticker.is_none() {
+ let progress = self.clone();
+ inner.ticker = Some(std::thread::spawn(move || {
+ loop {
+ std::thread::sleep(STATUS_TICK_INTERVAL);
+
+ let mut inner = progress.inner.lock().unwrap();
+ if inner.in_progress.is_empty() {
+ break;
+ }
+
+ Progress::clear_status_lines(&mut inner);
+ Progress::print_status_lines(&mut inner);
+ }
+ }));
+ }
+
Self::print_status_lines(&mut inner);
StepProgress {
@@ -179,7 +209,9 @@ impl Progress {
Self::clear_status_lines(&mut inner);
Self::print_logging_closing_divider(&mut inner);
- Self::print_completed(&inner, inner.completed.last().unwrap());
+ if let Some(last_completed) = inner.completed.last() {
+ Self::print_completed(&inner, last_completed);
+ }
Self::print_status_lines(&mut inner);
} else {
inner.in_progress.insert(example_name.to_string(), task);
@@ -210,6 +242,7 @@ impl Progress {
fn print_completed(inner: &ProgressInner, task: &CompletedTask) {
let duration = format_duration(task.duration);
let name_width = inner.max_example_name_len;
+ let truncated_name = truncate_with_ellipsis(&task.example_name, name_width);
if inner.is_tty {
let reset = "\x1b[0m";
@@ -233,7 +266,7 @@ impl Progress {
"{bold}{color}{label:>12}{reset} {name:<name_width$} {dim}│{reset} {info_part}",
color = task.step.color_code(),
label = task.step.label(),
- name = task.example_name,
+ name = truncated_name,
);
let duration_with_margin = format!("{duration} ");
@@ -255,7 +288,7 @@ impl Progress {
eprintln!(
"{label:>12} {name:<name_width$}{info_part} {duration}",
label = task.step.label(),
- name = task.example_name,
+ name = truncate_with_ellipsis(&task.example_name, name_width),
);
}
}
@@ -283,7 +316,7 @@ impl Progress {
let range_label = format!(
" {}/{}/{} ",
- done_count, in_progress_count, inner.total_examples
+ done_count, in_progress_count, inner.total_steps
);
// Print a divider line with failed count on left, range label on right
@@ -318,10 +351,11 @@ impl Progress {
let step_label = task.step.label();
let step_color = task.step.color_code();
let name_width = inner.max_example_name_len;
+ let truncated_name = truncate_with_ellipsis(name, name_width);
let prefix = format!(
"{bold}{step_color}{step_label:>12}{reset} {name:<name_width$} {dim}│{reset} {substatus_part}",
- name = name,
+ name = truncated_name,
);
let duration_with_margin = format!("{elapsed} ");
@@ -348,6 +382,15 @@ impl Progress {
}
pub fn finalize(&self) {
+ let ticker = {
+ let mut inner = self.inner.lock().unwrap();
+ inner.ticker.take()
+ };
+
+ if let Some(ticker) = ticker {
+ let _ = ticker.join();
+ }
+
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);

pull_examples module (new file):
@@ -0,0 +1,220 @@
+use anyhow::{Context as _, Result};
+use http_client::{AsyncBody, HttpClient, Method, Request};
+use indoc::indoc;
+use serde::Deserialize;
+use serde_json::{Value as JsonValue, json};
+use std::sync::Arc;
+
+use crate::{
+ example::Example,
+ progress::{InfoStyle, Progress, Step},
+};
+use edit_prediction::example_spec::ExampleSpec;
+
+const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
+const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
+
+const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+
+/// Parse an input token of the form `captured-after:{timestamp}`.
+pub fn parse_captured_after_input(input: &str) -> Option<&str> {
+ input.strip_prefix("captured-after:")
+}
+
+pub async fn fetch_captured_examples_after(
+ http_client: Arc<dyn HttpClient>,
+ after_timestamps: &[String],
+ max_rows_per_timestamp: usize,
+) -> Result<Vec<Example>> {
+ if after_timestamps.is_empty() {
+ return Ok(Vec::new());
+ }
+
+ let progress = Progress::global();
+
+ let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+ .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+ let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+ "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+ )?;
+ let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+ let mut all_examples = Vec::new();
+
+ for after_date in after_timestamps.iter() {
+ let step_progress_name = format!(">{after_date}");
+ let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+ step_progress.set_substatus("querying");
+
+ let statement = indoc! {r#"
+ SELECT
+ event_properties:example AS example
+ FROM events
+ WHERE event_type = ?
+ AND time > TRY_TO_TIMESTAMP_NTZ(?)
+ ORDER BY time ASC
+ LIMIT ?
+ "#};
+
+ let request = json!({
+ "statement": statement,
+ "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ "database": "EVENTS",
+ "schema": "PUBLIC",
+ "warehouse": "DBT",
+ "role": role,
+ "bindings": {
+ "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
+ "2": { "type": "TEXT", "value": after_date },
+ "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+ }
+ });
+
+ let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
+
+ step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
+ step_progress.set_substatus("parsing");
+
+ all_examples.extend(examples_from_response(&response)?);
+
+ step_progress.set_substatus("done");
+ }
+
+ Ok(all_examples)
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeStatementResponse {
+ #[serde(default)]
+ data: Vec<Vec<JsonValue>>,
+ #[serde(default)]
+ result_set_meta_data: Option<SnowflakeResultSetMetaData>,
+ #[serde(default)]
+ code: Option<String>,
+ #[serde(default)]
+ message: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeResultSetMetaData {
+ #[serde(default, rename = "rowType")]
+ row_type: Vec<SnowflakeColumnMeta>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeColumnMeta {
+ #[serde(default)]
+ name: String,
+}
+
+fn examples_from_response(
+ response: &SnowflakeStatementResponse,
+) -> Result<impl Iterator<Item = Example>> {
+ if let Some(code) = &response.code {
+ if code != SNOWFLAKE_SUCCESS_CODE {
+ anyhow::bail!(
+ "snowflake sql api returned error code={code} message={}",
+ response.message.as_deref().unwrap_or("<no message>")
+ );
+ }
+ }
+
+ let example_index = response
+ .result_set_meta_data
+ .as_ref()
+ .and_then(|m| {
+ m.row_type.iter().enumerate().find_map(|(index, col)| {
+ if col.name.eq_ignore_ascii_case("example") {
+ Some(index)
+ } else {
+ None
+ }
+ })
+ })
+ .unwrap_or(0);
+
+ let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
+ let Some(example_value) = data_row.get(example_index) else {
+ return None;
+ };
+ if example_value.is_null() {
+ return None;
+ }
+
+ let parse_result = match example_value {
+ JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
+ _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
+ };
+
+ match parse_result {
+ Ok(spec) => Some(Example {
+ spec,
+ buffer: None,
+ context: None,
+ prompt: None,
+ predictions: Vec::new(),
+ score: Vec::new(),
+ state: None,
+ }),
+ Err(error) => {
+ let raw_json = serde_json::to_string_pretty(example_value)
+ .unwrap_or_else(|_| "<failed to serialize json>".to_string());
+ log::error!(
+ "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
+ );
+ None
+ }
+ }
+ });
+
+ Ok(iter)
+}
+
+async fn run_sql(
+ http_client: Arc<dyn HttpClient>,
+ base_url: &str,
+ token: &str,
+ request: &serde_json::Value,
+) -> Result<SnowflakeStatementResponse> {
+ let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
+
+ let request_body =
+ serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
+
+ let http_request = Request::builder()
+ .method(Method::POST)
+ .uri(url.as_str())
+ .header("Authorization", format!("Bearer {token}"))
+ .header(
+ "X-Snowflake-Authorization-Token-Type",
+ "PROGRAMMATIC_ACCESS_TOKEN",
+ )
+ .header("Content-Type", "application/json")
+ .header("Accept", "application/json")
+ .body(AsyncBody::from(request_body.clone()))?;
+
+ let response = http_client
+ .send(http_request)
+ .await
+ .context("failed to send request to Snowflake SQL API")?;
+
+ let status = response.status();
+ let body_bytes = {
+ use futures::AsyncReadExt as _;
+
+ let mut body = response.into_body();
+ let mut bytes = Vec::new();
+ body.read_to_end(&mut bytes)
+ .await
+ .context("failed to read Snowflake SQL API response body")?;
+ bytes
+ };
+
+ if !status.is_success() {
+ let body_text = String::from_utf8_lossy(&body_bytes);
+ anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
+ }
+
+ serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
+ .context("failed to parse Snowflake SQL API response JSON")
+}

split_commit module:
@@ -16,7 +16,7 @@ use std::fs;
use std::io::{self, Read};
/// `ep split-commit` CLI args.
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
pub struct SplitCommitArgs {
/// Path to the commit file (use "-" for stdin)
#[arg(long, short = 'c')]

synthesize module:
@@ -108,7 +108,7 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
let progress = Progress::global();
- progress.set_total_examples(config.count);
+ progress.set_total_steps(config.count);
let clone_progress = progress.start(Step::Synthesize, "clone");
let repo_path = ensure_repo_cloned(&config.repo_url).await?;