diff --git a/Cargo.lock b/Cargo.lock
index b5fd29820fe315c5a7e8a89544020fed3130cfd0..17916306852a0a596f7a2e36ba95479bb857adba 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5292,6 +5292,7 @@ dependencies = [
  "dirs 4.0.0",
  "edit_prediction",
  "extension",
+ "flate2",
  "fs",
  "futures 0.3.31",
  "gpui",
@@ -6252,9 +6253,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"

 [[package]]
 name = "flate2"
-version = "1.1.4"
+version = "1.1.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
+checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369"
 dependencies = [
  "crc32fast",
  "miniz_oxide",
diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs
index f5d6088c7cb70494ea89418fe2e0dbcd79ec0b57..f172c532bf9f7a3bd5a0744d4752d1877847ad1e 100644
--- a/crates/edit_prediction/src/capture_example.rs
+++ b/crates/edit_prediction/src/capture_example.rs
@@ -13,6 +13,7 @@ use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::
 use text::{BufferSnapshot as TextBufferSnapshot, Point};

 pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
+pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;

 pub fn capture_example(
     project: Entity<Project>,
@@ -232,10 +233,15 @@ fn generate_timestamp_name() -> String {
 }

 pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
+    let default_rate = if cx.is_staff() {
+        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
+    } else {
+        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
+    };
     let capture_rate = language::language_settings::all_language_settings(None, cx)
         .edit_predictions
         .example_capture_rate
-        .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
+        .unwrap_or(default_rate);

     cx.has_flag::() && rand::random::<u16>() % 10_000 < capture_rate
 }
diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs
index 163473804dfc50bfcda862261409c7006adf9b00..015bef2f0a0af35b1d807ec0596ae08542ae5ba0 100644
--- a/crates/edit_prediction/src/udiff.rs
+++ b/crates/edit_prediction/src/udiff.rs
@@ -214,6 +214,54 @@ pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
     Ok(result)
 }

+pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
+    if prefix.is_empty() {
+        return Cow::Borrowed(diff);
+    }
+
+    let prefix_with_slash = format!("{}/", prefix);
+    let mut needs_rewrite = false;
+
+    for line in diff.lines() {
+        match DiffLine::parse(line) {
+            DiffLine::OldPath { path } | DiffLine::NewPath { path } => {
+                if path.starts_with(&prefix_with_slash) {
+                    needs_rewrite = true;
+                    break;
+                }
+            }
+            _ => {}
+        }
+    }
+
+    if !needs_rewrite {
+        return Cow::Borrowed(diff);
+    }
+
+    let mut result = String::with_capacity(diff.len());
+    for line in diff.lines() {
+        match DiffLine::parse(line) {
+            DiffLine::OldPath { path } => {
+                let stripped = path
+                    .strip_prefix(&prefix_with_slash)
+                    .unwrap_or(path.as_ref());
+                result.push_str(&format!("--- a/{}\n", stripped));
+            }
+            DiffLine::NewPath { path } => {
+                let stripped = path
+                    .strip_prefix(&prefix_with_slash)
+                    .unwrap_or(path.as_ref());
+                result.push_str(&format!("+++ b/{}\n", stripped));
+            }
+            _ => {
+                result.push_str(line);
+                result.push('\n');
+            }
+        }
+    }
+
+    Cow::Owned(result)
+}
+
 /// Strip unnecessary git metadata lines from a diff, keeping only the lines
 /// needed for patch application: path headers (--- and +++), hunk headers (@@),
 /// and content lines (+, -, space).
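A quick sketch of the contract the new `strip_diff_path_prefix` helper is meant to satisfy. `DiffLine` is not shown in this patch, so the sketch assumes `DiffLine::parse` classifies `--- a/...`/`+++ b/...` headers and yields the path without the `a/`/`b/` marker; the diff text and the `zed` prefix are illustrative, not taken from a real example:

```rust
// Hypothetical test module placed alongside strip_diff_path_prefix in udiff.rs.
#[cfg(test)]
mod strip_prefix_tests {
    use super::strip_diff_path_prefix;
    use std::borrow::Cow;

    #[test]
    fn strips_repo_name_from_path_headers() {
        // Only the ---/+++ headers are rewritten; hunk and content lines pass through.
        let diff = "--- a/zed/crates/foo.rs\n+++ b/zed/crates/foo.rs\n@@ -1 +1 @@\n-old\n+new\n";
        let stripped = strip_diff_path_prefix(diff, "zed");
        assert_eq!(
            stripped.as_ref(),
            "--- a/crates/foo.rs\n+++ b/crates/foo.rs\n@@ -1 +1 @@\n-old\n+new\n"
        );

        // When no header carries the prefix, the input comes back borrowed,
        // so the common (new-style) case allocates nothing.
        let unprefixed = "--- a/crates/foo.rs\n+++ b/crates/foo.rs\n@@ -1 +1 @@\n-old\n+new\n";
        assert!(matches!(
            strip_diff_path_prefix(unprefixed, "zed"),
            Cow::Borrowed(_)
        ));
    }
}
```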
diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml
index b8d35c2b8be2a961aa0267883013be0c058d1697..da65c68a76b87310640b299a4a3f2d79e0200266 100644
--- a/crates/edit_prediction_cli/Cargo.toml
+++ b/crates/edit_prediction_cli/Cargo.toml
@@ -57,6 +57,7 @@ wasmtime.workspace = true
 zeta_prompt.workspace = true
 rand.workspace = true
 similar = "2.7.0"
+flate2 = "1.1.8"

 # Wasmtime is included as a dependency in order to enable the same
 # features that are enabled in Zed.
diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs
index 8fda65452a4badf6dc21277058e413a29d000e98..b37fa79ee35c55773c9745d4833fe3cbd19f4120 100644
--- a/crates/edit_prediction_cli/src/load_project.rs
+++ b/crates/edit_prediction_cli/src/load_project.rs
@@ -5,7 +5,7 @@ use crate::{
     progress::{InfoStyle, Progress, Step, StepProgress},
 };
 use anyhow::{Context as _, Result};
-use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries};
+use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix};
 use edit_prediction::{
     EditPredictionStore, cursor_excerpt::editable_and_context_ranges_for_cursor_position, zeta2,
 };
@@ -111,8 +111,16 @@ async fn cursor_position(
     }

     let cursor_path_str = example.spec.cursor_path.to_string_lossy();
+    // Also try the cursor path with its first component stripped; old examples
+    // may have paths like "zed/crates/foo.rs" instead of "crates/foo.rs".
+    let cursor_path_without_prefix: PathBuf =
+        example.spec.cursor_path.components().skip(1).collect();
+    let cursor_path_without_prefix_str = cursor_path_without_prefix.to_string_lossy();
+
     // We try open_buffers first because the file might be new and not saved to disk
-    let cursor_buffer = if let Some(buffer) = open_buffers.get(&cursor_path_str) {
+    let cursor_buffer = if let Some(buffer) = open_buffers.get(cursor_path_str.as_ref()) {
+        buffer.clone()
+    } else if let Some(buffer) = open_buffers.get(cursor_path_without_prefix_str.as_ref()) {
         buffer.clone()
     } else {
         // Since the worktree scanner is disabled, manually refresh entries for the cursor path.
@@ -122,7 +130,9 @@ async fn cursor_position(

     let cursor_path = project
         .read_with(cx, |project, cx| {
-            project.find_project_path(&example.spec.cursor_path, cx)
+            project
+                .find_project_path(&example.spec.cursor_path, cx)
+                .or_else(|| project.find_project_path(&cursor_path_without_prefix, cx))
         })
         .with_context(|| {
             format!(
@@ -282,9 +292,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
     }

     drop(repo_lock);

-    // Apply the uncommitted diff for this example.
     if !example.spec.uncommitted_diff.is_empty() {
         step_progress.set_substatus("applying diff");
+
+        // Old examples prefixed paths in the uncommitted diff with the repo name.
+        let uncommitted_diff =
+            strip_diff_path_prefix(&example.spec.uncommitted_diff, &repo_name.name);
+
         let mut apply_process = smol::process::Command::new("git")
             .current_dir(&worktree_path)
             .args(&["apply", "-"])
@@ -292,9 +306,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
             .spawn()?;

         let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
-        stdin
-            .write_all(example.spec.uncommitted_diff.as_bytes())
-            .await?;
+        stdin.write_all(uncommitted_diff.as_bytes()).await?;
         stdin.close().await?;

         drop(stdin);
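The path fallback in `cursor_position` drops the first component so old-style example paths still resolve against the worktree. A standalone sketch of that transformation (the helper name is hypothetical), including the edge case that explains why it is only tried after the exact lookup:

```rust
use std::path::{Path, PathBuf};

// Drop the leading component, mapping "zed/crates/foo.rs" -> "crates/foo.rs".
fn without_first_component(path: &Path) -> PathBuf {
    path.components().skip(1).collect()
}

fn main() {
    let old_style = Path::new("zed/crates/editor/src/editor.rs");
    assert_eq!(
        without_first_component(old_style),
        PathBuf::from("crates/editor/src/editor.rs")
    );

    // A single-component path collapses to an empty PathBuf, so the stripped
    // variant is only safe as a fallback after trying the path as given.
    assert_eq!(without_first_component(Path::new("README.md")), PathBuf::new());
    println!("path fallback behaves as expected");
}
```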
diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index 1829ad18d80bee2a1c28c7da7f68ea910ba56d74..34040136cf3b48ff180ff232296f3016300b4161 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -21,7 +21,7 @@ use collections::HashSet;
 use edit_prediction::EditPredictionStore;
 use futures::channel::mpsc;
 use futures::{SinkExt as _, StreamExt as _};
-use gpui::{AppContext as _, Application};
+use gpui::{AppContext as _, Application, BackgroundExecutor};
 use zeta_prompt::ZetaVersion;
 use reqwest_client::ReqwestClient;

@@ -279,6 +279,7 @@ async fn load_examples(
     http_client: Arc<dyn HttpClient>,
     args: &EpArgs,
     output_path: Option<&PathBuf>,
+    background_executor: BackgroundExecutor,
 ) -> anyhow::Result<Vec<Example>> {
     let mut captured_after_timestamps = Vec::new();
     let mut file_inputs = Vec::new();
@@ -312,6 +313,7 @@ async fn load_examples(
             http_client,
             &captured_after_timestamps,
             max_rows_per_timestamp,
+            background_executor,
         )
         .await?;
         examples.append(&mut captured_examples);
@@ -465,8 +467,13 @@ fn main() {
         cx.spawn(async move |cx| {
             let result = async {
-                let mut examples =
-                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
+                let mut examples = load_examples(
+                    app_state.client.http_client(),
+                    &args,
+                    output.as_ref(),
+                    cx.background_executor().clone(),
+                )
+                .await?;

                 match &command {
                     Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
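`load_examples` now threads a `BackgroundExecutor` down from `main` because the Snowflake polling loop added below sleeps between attempts with executor timers instead of blocking a thread. A minimal sketch of that pattern, assuming gpui's headless `Application` entry point as this CLI appears to use; the loop bounds and interval here are illustrative:

```rust
use gpui::{AppContext as _, Application};
use std::time::Duration;

fn main() {
    Application::headless().run(|cx| {
        let executor = cx.background_executor().clone();
        cx.spawn(async move |cx| {
            // BackgroundExecutor::timer returns a future that resolves after
            // the interval, so a poll loop can wait cooperatively.
            for attempt in 1..=3 {
                executor.timer(Duration::from_millis(200)).await;
                println!("poll attempt {attempt}");
            }
            cx.update(|cx| cx.quit()).ok();
        })
        .detach();
    });
}
```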
diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs
index 91ffa53c4453d918082a6a1e7e9d84abb7d60770..eb0c5cbcd42cf1188ce555cdc76ee879356966bd 100644
--- a/crates/edit_prediction_cli/src/pull_examples.rs
+++ b/crates/edit_prediction_cli/src/pull_examples.rs
@@ -1,9 +1,13 @@
 use anyhow::{Context as _, Result};
+use flate2::read::GzDecoder;
+use gpui::BackgroundExecutor;
 use http_client::{AsyncBody, HttpClient, Method, Request};
 use indoc::indoc;
 use serde::Deserialize;
 use serde_json::{Value as JsonValue, json};
+use std::io::Read;
 use std::sync::Arc;
+use std::time::Duration;

 use crate::{
     example::Example,
@@ -12,9 +16,12 @@ use crate::{
 use edit_prediction::example_spec::ExampleSpec;

 const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
+const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
 const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
 const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+const POLL_INTERVAL: Duration = Duration::from_secs(2);
+const MAX_POLL_ATTEMPTS: usize = 120;

 /// Parse an input token of the form `captured-after:{timestamp}`.
 pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@@ -25,6 +32,7 @@ pub async fn fetch_captured_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
     max_rows_per_timestamp: usize,
+    background_executor: BackgroundExecutor,
 ) -> Result<Vec<Example>> {
     if after_timestamps.is_empty() {
         return Ok(Vec::new());
@@ -70,13 +78,60 @@ pub async fn fetch_captured_examples_after(
             }
         });

-        let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
-
-        step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
+        let response = run_sql_with_polling(
+            http_client.clone(),
+            &base_url,
+            &token,
+            &request,
+            &step_progress,
+            background_executor.clone(),
+        )
+        .await?;
+
+        let total_rows = response
+            .result_set_meta_data
+            .as_ref()
+            .and_then(|m| m.num_rows)
+            .unwrap_or(response.data.len() as i64);
+
+        let num_partitions = response
+            .result_set_meta_data
+            .as_ref()
+            .map(|m| m.partition_info.len())
+            .unwrap_or(1)
+            .max(1);
+
+        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
         step_progress.set_substatus("parsing");

         all_examples.extend(examples_from_response(&response)?);

+        if num_partitions > 1 {
+            let statement_handle = response
+                .statement_handle
+                .as_ref()
+                .context("response has multiple partitions but no statementHandle")?;
+
+            for partition in 1..num_partitions {
+                step_progress.set_substatus(format!(
+                    "fetching partition {}/{}",
+                    partition + 1,
+                    num_partitions
+                ));
+
+                let partition_response = fetch_partition(
+                    http_client.clone(),
+                    &base_url,
+                    &token,
+                    statement_handle,
+                    partition,
+                )
+                .await?;
+
+                all_examples.extend(examples_from_response(&partition_response)?);
+            }
+        }
+
         step_progress.set_substatus("done");
     }

@@ -84,6 +139,7 @@ pub async fn fetch_captured_examples_after(
 }

 #[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
 struct SnowflakeStatementResponse {
     #[serde(default)]
     data: Vec<Vec<JsonValue>>,
@@ -93,14 +149,25 @@ struct SnowflakeStatementResponse {
     code: Option<String>,
     #[serde(default)]
     message: Option<String>,
+    #[serde(default)]
+    statement_handle: Option<String>,
 }

 #[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
 struct SnowflakeResultSetMetaData {
     #[serde(default, rename = "rowType")]
     row_type: Vec<SnowflakeColumnMeta>,
+    #[serde(default)]
+    num_rows: Option<i64>,
+    #[serde(default)]
+    partition_info: Vec<SnowflakePartitionInfo>,
 }

+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SnowflakePartitionInfo {}
+
 #[derive(Debug, Clone, Deserialize)]
 struct SnowflakeColumnMeta {
     #[serde(default)]
@@ -109,7 +176,7 @@ struct SnowflakeColumnMeta {
 fn examples_from_response(
     response: &SnowflakeStatementResponse,
-) -> Result<impl Iterator<Item = Example>> {
+) -> Result<impl Iterator<Item = Example> + '_> {
     if let Some(code) = &response.code {
         if code != SNOWFLAKE_SUCCESS_CODE {
             anyhow::bail!(
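The structs above lean on `#[serde(rename_all = "camelCase")]` to pick up Snowflake's `statementHandle`, `numRows`, and `partitionInfo` fields. A self-contained sketch of that mapping and of the in-progress code check; the struct is trimmed to two fields and the handle value is made up:

```rust
use serde::Deserialize;

// Mirrors the shape of SnowflakeStatementResponse above, trimmed for illustration.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StatementResponse {
    #[serde(default)]
    code: Option<String>,
    #[serde(default)]
    statement_handle: Option<String>, // deserialized from "statementHandle"
}

fn main() {
    let body = r#"{"code":"333334","statementHandle":"01b2c3d4-0000-0000-0000-000000000000"}"#;
    let response: StatementResponse = serde_json::from_str(body).unwrap();

    // Code 333334 is Snowflake's async-in-progress status: the query is still
    // running and the caller should poll the statement handle.
    assert_eq!(response.code.as_deref(), Some("333334"));
    assert!(response.statement_handle.is_some());
    println!("parsed handle: {:?}", response.statement_handle);
}
```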
@@ -169,6 +236,136 @@
     Ok(iter)
 }

+async fn run_sql_with_polling(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    request: &serde_json::Value,
+    step_progress: &crate::progress::StepProgress,
+    background_executor: BackgroundExecutor,
+) -> Result<SnowflakeStatementResponse> {
+    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
+
+    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+        let statement_handle = response
+            .statement_handle
+            .as_ref()
+            .context("async query response missing statementHandle")?
+            .clone();
+
+        for attempt in 1..=MAX_POLL_ATTEMPTS {
+            step_progress.set_substatus(format!("polling ({attempt})"));
+
+            background_executor.timer(POLL_INTERVAL).await;
+
+            response =
+                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
+
+            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+                break;
+            }
+        }
+
+        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+            anyhow::bail!(
+                "query still running after {} poll attempts ({} seconds)",
+                MAX_POLL_ATTEMPTS,
+                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
+            );
+        }
+    }
+
+    Ok(response)
+}
+
+async fn fetch_partition(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    statement_handle: &str,
+    partition: usize,
+) -> Result<SnowflakeStatementResponse> {
+    let url = format!(
+        "{}/api/v2/statements/{}?partition={}",
+        base_url.trim_end_matches('/'),
+        statement_handle,
+        partition
+    );
+
+    let http_request = Request::builder()
+        .method(Method::GET)
+        .uri(url.as_str())
+        .header("Authorization", format!("Bearer {token}"))
+        .header(
+            "X-Snowflake-Authorization-Token-Type",
+            "PROGRAMMATIC_ACCESS_TOKEN",
+        )
+        .header("Accept", "application/json")
+        .header("Accept-Encoding", "gzip")
+        .body(AsyncBody::empty())?;
+
+    let response = http_client
+        .send(http_request)
+        .await
+        .context("failed to send partition request to Snowflake SQL API")?;
+
+    let status = response.status();
+    let content_encoding = response
+        .headers()
+        .get("content-encoding")
+        .and_then(|v| v.to_str().ok())
+        .map(|s| s.to_lowercase());
+
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read Snowflake SQL API partition response body")?;
+        bytes
+    };
+
+    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
+        let mut decoder = GzDecoder::new(&body_bytes[..]);
+        let mut decompressed = Vec::new();
+        decoder
+            .read_to_end(&mut decompressed)
+            .context("failed to decompress gzip response")?;
+        decompressed
+    } else {
+        body_bytes
+    };
+
+    if !status.is_success() && status.as_u16() != 202 {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!(
+            "snowflake sql api partition request http {}: {}",
+            status.as_u16(),
+            body_text
+        );
+    }
+
+    if body_bytes.is_empty() {
+        anyhow::bail!(
+            "snowflake sql api partition {} returned empty response body (http {})",
+            partition,
+            status.as_u16()
+        );
+    }
+
+    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
+        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
+        format!(
+            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
+            partition,
+            status.as_u16(),
+            body_preview
+        )
+    })
+}
+
 async fn run_sql(
     http_client: Arc<dyn HttpClient>,
     base_url: &str,
@@ -209,7 +406,7 @@ async fn run_sql(
         bytes
     };

-    if !status.is_success() {
+    if !status.is_success() && status.as_u16() != 202 {
         let body_text = String::from_utf8_lossy(&body_bytes);
         anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
     }
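The explicit gzip branch in `fetch_partition` exists because the body is read as raw bytes and `Accept-Encoding: gzip` is requested, so nothing decompresses the response automatically. A standalone sketch of the same decision, round-tripped through flate2 (the helper name is hypothetical):

```rust
use flate2::read::GzDecoder;
use std::io::Read;

// Decompress only when the server declared `Content-Encoding: gzip`;
// otherwise pass the body through untouched.
fn decode_body(body: Vec<u8>, content_encoding: Option<&str>) -> std::io::Result<Vec<u8>> {
    if content_encoding.map(|s| s.eq_ignore_ascii_case("gzip")) == Some(true) {
        let mut decompressed = Vec::new();
        GzDecoder::new(&body[..]).read_to_end(&mut decompressed)?;
        Ok(decompressed)
    } else {
        Ok(body)
    }
}

fn main() -> std::io::Result<()> {
    use flate2::{Compression, write::GzEncoder};
    use std::io::Write;

    // Round-trip: compress a payload, then decode it as the partition fetch would.
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(b"{\"data\":[]}")?;
    let compressed = encoder.finish()?;

    assert_eq!(decode_body(compressed, Some("gzip"))?, b"{\"data\":[]}");
    assert_eq!(decode_body(b"plain".to_vec(), None)?, b"plain");
    Ok(())
}
```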