diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 63a53b0d7dc667b05171d486e078617187f24fe6..441dcd3a4378159b0a65f02c9c194d8eb5fc6ec8 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -125,17 +125,9 @@ impl Example { } } -pub fn read_examples(inputs: &[PathBuf]) -> Vec { +pub fn read_example_files(inputs: &[PathBuf]) -> Vec { let mut examples = Vec::new(); - let stdin_path: PathBuf = PathBuf::from("-"); - - let inputs = if inputs.is_empty() { - &[stdin_path] - } else { - inputs - }; - for path in inputs { let is_stdin = path.as_path() == Path::new("-"); let content = if is_stdin { @@ -201,7 +193,6 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec { } } - sort_examples_by_repo_and_rev(&mut examples); examples } diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index b54ae89409adcc496f56d503e994df3132e76dc7..b7f843620ca353bf2f4743c3e26098807274a5b8 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -9,12 +9,12 @@ mod metrics; mod paths; mod predict; mod progress; +mod pull_examples; mod reorder_patch; mod retrieve_context; mod score; mod split_commit; mod synthesize; - use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; use edit_prediction::EditPredictionStore; use gpui::Application; @@ -24,7 +24,7 @@ use std::fmt::Display; use std::{path::PathBuf, sync::Arc}; use crate::distill::run_distill; -use crate::example::{group_examples_by_repo, read_examples, write_examples}; +use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples}; use crate::format_prompt::run_format_prompt; use crate::load_project::run_load_project; use crate::paths::FAILED_EXAMPLES_DIR; @@ -42,9 +42,11 @@ struct EpArgs { printenv: bool, #[clap(long, default_value_t = 10, global = true)] max_parallelism: usize, + #[clap(long, global = true)] + limit: Option, 
#[command(subcommand)] command: Option, - #[clap(global = true)] + #[clap(global = true, help = INPUTS_HELP)] inputs: Vec, #[arg(long, short, global = true)] output: Option, @@ -54,7 +56,37 @@ struct EpArgs { failfast: bool, } -#[derive(Subcommand, Debug)] +const INPUTS_HELP: &str = r#" +Inputs can be file paths or special specifiers: + + path + Path to an example(s) file (.md, .json, or .jsonl) + + captured-after:{timestamp} + Fetch captured examples from Snowflake after the given RFC3339 timestamp. + + You can specify this multiple times and mix it with file inputs. + + Required environment variables to connect to Snowflake: + EP_SNOWFLAKE_API_KEY + EP_SNOWFLAKE_BASE_URL + + Optional: + EP_SNOWFLAKE_ROLE + +Examples: + + # Predict from a file + ep predict examples.jsonl + + # Predict from captured examples after a timestamp + ep predict captured-after:2025-01-01T00:00:00Z + + # Mix file inputs and captured-after in the same invocation + ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z +"#; + +#[derive(Subcommand, Debug, Clone)] enum Command { /// Parse markdown examples and output a combined .jsonl file ParseExample, @@ -137,7 +169,7 @@ impl Display for Command { } } -#[derive(Debug, Args)] +#[derive(Debug, Args, Clone)] struct FormatPromptArgs { #[clap(long)] prompt_format: PromptFormat, @@ -149,7 +181,7 @@ enum PromptFormat { Zeta2, } -#[derive(Debug, Args)] +#[derive(Debug, Args, Clone)] struct PredictArgs { #[clap(long)] provider: PredictionProvider, @@ -167,7 +199,7 @@ enum PredictionProvider { TeacherNonBatching, } -#[derive(Debug, Args)] +#[derive(Debug, Args, Clone)] struct SynthesizeArgs { /// Repository URL (git@github.com:owner/repo or https://...) 
#[clap(long)]
@@ -200,6 +232,60 @@ impl EpArgs {
     }
 }
 
+/// Gathers the examples to process from every CLI input.
+///
+/// Inputs recognized by `pull_examples::parse_captured_after_input` are
+/// treated as timestamps and fetched through
+/// `pull_examples::fetch_captured_examples_after`; every other input is read
+/// as an example file. The combined list is passed through
+/// `sort_examples_by_repo_and_rev` and then cut down to `--limit` when one
+/// was given.
+///
+/// # Errors
+///
+/// Returns an error when fetching the captured examples fails.
+async fn load_examples(
+    http_client: Arc<dyn http_client::HttpClient>,
+    args: &EpArgs,
+) -> anyhow::Result<Vec<Example>> {
+    let mut captured_after_timestamps = Vec::new();
+    let mut file_inputs = Vec::new();
+
+    for input in &args.inputs {
+        let input_string = input.to_string_lossy();
+        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
+            captured_after_timestamps.push(timestamp.to_string());
+        } else {
+            file_inputs.push(input.clone());
+        }
+    }
+
+    let mut examples = read_example_files(&file_inputs);
+    let total_steps = examples.len() + captured_after_timestamps.len();
+    Progress::global().set_total_steps(total_steps);
+
+    // Only report or perform a pull when a captured-after input was actually
+    // given; otherwise the "skipping" message would be misleading.
+    if !captured_after_timestamps.is_empty() {
+        // File examples count against `--limit` first; only the remainder is
+        // requested from Snowflake.
+        let remaining_limit_for_snowflake =
+            args.limit.map(|limit| limit.saturating_sub(examples.len()));
+
+        if remaining_limit_for_snowflake == Some(0) {
+            log::info!(
+                "skipping captured-after inputs because --limit is already satisfied by example files"
+            );
+        } else {
+            captured_after_timestamps.sort();
+
+            // Without `--limit`, each timestamp is capped at 5000 rows.
+            let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
+
+            let mut captured_examples = pull_examples::fetch_captured_examples_after(
+                http_client,
+                &captured_after_timestamps,
+                max_rows_per_timestamp,
+            )
+            .await?;
+            examples.append(&mut captured_examples);
+        }
+    }
+
+    crate::example::sort_examples_by_repo_and_rev(&mut examples);
+
+    // `truncate` is already a no-op when the list is not longer than `limit`.
+    if let Some(limit) = args.limit {
+        examples.truncate(limit);
+    }
+
+    Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
+
+    Ok(examples)
+}
+
 fn main() {
     let args = EpArgs::parse();
 
@@ -209,8 +295,8 @@ fn main() {
     }
 
     let output = args.output_path();
-    let command = match args.command {
-        Some(cmd) => cmd,
+    let command = match &args.command {
+        Some(cmd) => cmd.clone(),
         None => {
             EpArgs::command().print_help().unwrap();
             return;
@@ -251,7 +337,6 @@ fn main() {
         _ => {}
     }
 
-    let mut examples = read_examples(&args.inputs);
     let
http_client = Arc::new(ReqwestClient::new()); let app = Application::headless().with_http_client(http_client); @@ -261,12 +346,13 @@ fn main() { cx.spawn(async move |cx| { let result = async { + let mut examples = load_examples(app_state.client.http_client(), &args).await?; + if let Command::Predict(args) = &command { predict::sync_batches(&args.provider).await?; } - let total_examples = examples.len(); - Progress::global().set_total_examples(total_examples); + let failfast_on_single_example = examples.len() == 1; let mut grouped_examples = group_examples_by_repo(&mut examples); let example_batches = grouped_examples.chunks_mut(args.max_parallelism); @@ -347,7 +433,7 @@ fn main() { let msg = format!( indoc::indoc! {" - While processing {}: + While processing \"{}\": {:?} @@ -366,7 +452,7 @@ fn main() { command, failed_example_path.display(), ); - if args.failfast || total_examples == 1 { + if args.failfast || failfast_on_single_example { Progress::global().finalize(); panic!("{}", msg); } else { diff --git a/crates/edit_prediction_cli/src/progress.rs b/crates/edit_prediction_cli/src/progress.rs index c6157b1de9c8f09b1442ca9f3badf02c139b2b01..c2878e45812ea503e2817f744f9d5993359914b7 100644 --- a/crates/edit_prediction_cli/src/progress.rs +++ b/crates/edit_prediction_cli/src/progress.rs @@ -19,9 +19,10 @@ struct ProgressInner { terminal_width: usize, max_example_name_len: usize, status_lines_displayed: usize, - total_examples: usize, + total_steps: usize, failed_examples: usize, last_line_is_logging: bool, + ticker: Option>, } #[derive(Clone)] @@ -47,6 +48,7 @@ pub enum Step { Predict, Score, Synthesize, + PullExamples, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -64,6 +66,7 @@ impl Step { Step::Predict => "Predict", Step::Score => "Score", Step::Synthesize => "Synthesize", + Step::PullExamples => "Pull", } } @@ -75,6 +78,7 @@ impl Step { Step::Predict => "\x1b[32m", Step::Score => "\x1b[31m", Step::Synthesize => "\x1b[36m", + Step::PullExamples => "\x1b[36m", 
} } }
@@ -84,6 +88,7 @@ static LOGGER: ProgressLogger = ProgressLogger;
 
 const MARGIN: usize = 4;
 const MAX_STATUS_LINES: usize = 10;
+/// Pause between two redraws of the status lines by the background ticker.
+const STATUS_TICK_INTERVAL: Duration = Duration::from_millis(300);
 
 impl Progress {
     /// Returns the global Progress instance, initializing it if necessary.
@@ -98,9 +103,10 @@ impl Progress {
                     terminal_width: get_terminal_width(),
                     max_example_name_len: 0,
                     status_lines_displayed: 0,
-                    total_examples: 0,
+                    total_steps: 0,
                     failed_examples: 0,
                     last_line_is_logging: false,
+                    ticker: None,
                 }),
             });
             let _ = log::set_logger(&LOGGER);
@@ -110,9 +116,9 @@ impl Progress {
             .clone()
     }
 
-    pub fn set_total_examples(&self, total: usize) {
+    pub fn set_total_steps(&self, total: usize) {
         let mut inner = self.inner.lock().unwrap();
-        inner.total_examples = total;
+        inner.total_steps = total;
     }
 
     pub fn increment_failed(&self) {
@@ -142,7 +148,14 @@ impl Progress {
 
         Self::clear_status_lines(&mut inner);
 
-        inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
+        // Cap the name column at a third of the usable terminal width
+        // (never below one character).
+        let max_name_width = inner
+            .terminal_width
+            .saturating_sub(MARGIN * 2)
+            .saturating_div(3)
+            .max(1);
+        inner.max_example_name_len = inner
+            .max_example_name_len
+            .max(example_name.len().min(max_name_width));
         inner.in_progress.insert(
             example_name.to_string(),
             InProgressTask {
@@ -153,6 +166,23 @@ impl Progress {
             },
         );
 
+        // On a TTY, keep one background thread that redraws the status lines
+        // every `STATUS_TICK_INTERVAL` while any task is in progress.
+        if inner.is_tty && inner.ticker.is_none() {
+            let progress = self.clone();
+            inner.ticker = Some(std::thread::spawn(move || {
+                loop {
+                    std::thread::sleep(STATUS_TICK_INTERVAL);
+
+                    let mut inner = progress.inner.lock().unwrap();
+                    if inner.in_progress.is_empty() {
+                        // Clear the handle before exiting so the next `start`
+                        // spawns a fresh ticker. Leaving it set would keep
+                        // `ticker.is_none()` false forever and the status
+                        // lines would stop refreshing for all later steps.
+                        inner.ticker = None;
+                        break;
+                    }
+
+                    Progress::clear_status_lines(&mut inner);
+                    Progress::print_status_lines(&mut inner);
+                }
+            }));
+        }
+
         Self::print_status_lines(&mut inner);
 
         StepProgress {
@@ -179,7 +209,9 @@ impl Progress {
             Self::clear_status_lines(&mut inner);
             Self::print_logging_closing_divider(&mut inner);
 
-            Self::print_completed(&inner, inner.completed.last().unwrap());
+            if let Some(last_completed) =
inner.completed.last() { + Self::print_completed(&inner, last_completed); + } Self::print_status_lines(&mut inner); } else { inner.in_progress.insert(example_name.to_string(), task); @@ -210,6 +242,7 @@ impl Progress { fn print_completed(inner: &ProgressInner, task: &CompletedTask) { let duration = format_duration(task.duration); let name_width = inner.max_example_name_len; + let truncated_name = truncate_with_ellipsis(&task.example_name, name_width); if inner.is_tty { let reset = "\x1b[0m"; @@ -233,7 +266,7 @@ impl Progress { "{bold}{color}{label:>12}{reset} {name:12} {name:12}{reset} {name: Option<&str> { + input.strip_prefix("captured-after:") +} + +pub async fn fetch_captured_examples_after( + http_client: Arc, + after_timestamps: &[String], + max_rows_per_timestamp: usize, +) -> Result> { + if after_timestamps.is_empty() { + return Ok(Vec::new()); + } + + let progress = Progress::global(); + + let token = std::env::var("EP_SNOWFLAKE_API_KEY") + .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; + let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( + "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", + )?; + let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); + + let mut all_examples = Vec::new(); + + for after_date in after_timestamps.iter() { + let step_progress_name = format!(">{after_date}"); + let step_progress = progress.start(Step::PullExamples, &step_progress_name); + step_progress.set_substatus("querying"); + + let statement = indoc! {r#" + SELECT + event_properties:example AS example + FROM events + WHERE event_type = ? + AND time > TRY_TO_TIMESTAMP_NTZ(?) + ORDER BY time ASC + LIMIT ? 
+        "#};
+
+        let request = json!({
+            "statement": statement,
+            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            "database": "EVENTS",
+            "schema": "PUBLIC",
+            "warehouse": "DBT",
+            "role": role,
+            "bindings": {
+                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
+                "2": { "type": "TEXT", "value": after_date },
+                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+            }
+        });
+
+        let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
+
+        step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
+        step_progress.set_substatus("parsing");
+
+        all_examples.extend(examples_from_response(&response)?);
+
+        step_progress.set_substatus("done");
+    }
+
+    Ok(all_examples)
+}
+
+/// The parts of a Snowflake SQL API statement response that this module reads.
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeStatementResponse {
+    /// Result rows; each row is a list of column values.
+    #[serde(default)]
+    data: Vec<Vec<JsonValue>>,
+    /// Column metadata. The SQL API sends this key in camelCase, so it has to
+    /// be renamed explicitly or it would always deserialize as `None`.
+    #[serde(default, rename = "resultSetMetaData")]
+    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
+    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE`.
+    #[serde(default)]
+    code: Option<String>,
+    /// Message reported alongside `code`.
+    #[serde(default)]
+    message: Option<String>,
+}
+
+/// Result-set metadata; only the column list is used.
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeResultSetMetaData {
+    #[serde(default, rename = "rowType")]
+    row_type: Vec<SnowflakeColumnMeta>,
+}
+
+/// Metadata of one result column; only its name is used.
+#[derive(Debug, Clone, Deserialize)]
+struct SnowflakeColumnMeta {
+    #[serde(default)]
+    name: String,
+}
+
+/// Turns the rows of a statement response into `Example`s.
+///
+/// The example payload is taken from the column named `example` (compared
+/// ASCII-case-insensitively) according to the result-set metadata, falling
+/// back to the first column when the metadata or that column is missing.
+/// Rows without a value at that index, or with a null value, are skipped.
+/// Rows that fail to parse as an `ExampleSpec` are logged and skipped.
+///
+/// # Errors
+///
+/// Fails when the response carries a `code` other than
+/// `SNOWFLAKE_SUCCESS_CODE`.
+fn examples_from_response(
+    response: &SnowflakeStatementResponse,
+) -> Result<impl Iterator<Item = Example> + '_> {
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake sql api returned error code={code} message={}",
+                response.message.as_deref().unwrap_or("")
+            );
+        }
+    }
+
+    let example_index = response
+        .result_set_meta_data
+        .as_ref()
+        .and_then(|meta| {
+            meta.row_type
+                .iter()
+                .position(|column| column.name.eq_ignore_ascii_case("example"))
+        })
+        .unwrap_or(0);
+
+    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
+        let example_value = data_row.get(example_index)?;
+        if example_value.is_null() {
+            return None;
+        }
+
+        // The payload may arrive as a JSON-encoded string or as a JSON value.
+        let parse_result = match example_value {
+            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
+            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
+        };
+
+        match parse_result {
+            Ok(spec) => Some(Example {
+                spec,
+                buffer: None,
+                context: None,
+                prompt: None,
+                predictions: Vec::new(),
+                score: Vec::new(),
+                state: None,
+            }),
+            Err(error) => {
+                let raw_json = serde_json::to_string_pretty(example_value)
+                    .unwrap_or_else(|_| "".to_string());
+                log::error!(
+                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
+                );
+                None
+            }
+        }
+    });
+
+    Ok(iter)
+}
+
+/// Posts `request` to `{base_url}/api/v2/statements` and decodes the JSON reply.
+///
+/// `token` is sent as a bearer token with the
+/// `PROGRAMMATIC_ACCESS_TOKEN` token-type header.
+///
+/// # Errors
+///
+/// Fails when the request cannot be serialized, built or sent, when the
+/// response body cannot be read, when the HTTP status is not a success (the
+/// status code and body text are included in the error), or when the body is
+/// not valid JSON for `SnowflakeStatementResponse`.
+async fn run_sql(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    request: &serde_json::Value,
+) -> Result<SnowflakeStatementResponse> {
+    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
+
+    let request_body =
+        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
+
+    // The serialized body is not needed afterwards, so move it into the request.
+    let http_request = Request::builder()
+        .method(Method::POST)
+        .uri(url.as_str())
+        .header("Authorization", format!("Bearer {token}"))
+        .header(
+            "X-Snowflake-Authorization-Token-Type",
+            "PROGRAMMATIC_ACCESS_TOKEN",
+        )
+        .header("Content-Type", "application/json")
+        .header("Accept", "application/json")
+        .body(AsyncBody::from(request_body))?;
+
+    let response = http_client
+        .send(http_request)
+        .await
+        .context("failed to send request to Snowflake SQL API")?;
+
+    let status = response.status();
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read Snowflake SQL API response body")?;
+        bytes
+    };
+
+    if !status.is_success() {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
+    }
+
+    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
+        .context("failed to parse Snowflake SQL API response JSON")
+}
diff --git a/crates/edit_prediction_cli/src/split_commit.rs b/crates/edit_prediction_cli/src/split_commit.rs
index 88be74511901a7704fdaa3934a01ad20abe3b032..64e5d2f2aa2a85530c5db614c94dc673f91dd83f 100644
--- a/crates/edit_prediction_cli/src/split_commit.rs
+++ b/crates/edit_prediction_cli/src/split_commit.rs
@@ -16,7 +16,7 @@ use std::fs;
 use std::io::{self, Read};
 
 /// `ep split-commit` CLI args.
-#[derive(Debug, Args)]
+#[derive(Debug, Args, Clone)]
 pub struct SplitCommitArgs {
     /// Path to the commit file (use "-" for stdin)
     #[arg(long, short = 'c')]
diff --git a/crates/edit_prediction_cli/src/synthesize.rs b/crates/edit_prediction_cli/src/synthesize.rs
index b79f84b1c712867b01ed3e5a27b96bf0dd1b56e3..b4d4975a15f2f6b50772fd43bd76fdea59dfa515 100644
--- a/crates/edit_prediction_cli/src/synthesize.rs
+++ b/crates/edit_prediction_cli/src/synthesize.rs
@@ -108,7 +108,7 @@ pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
     std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
 
     let progress = Progress::global();
-    progress.set_total_examples(config.count);
+    progress.set_total_steps(config.count);
     let clone_progress = progress.start(Step::Synthesize, "clone");
 
     let repo_path = ensure_repo_cloned(&config.repo_url).await?;