Fix issues processing captured edit prediction examples (#46773)

Created by Max Brunsfeld and Agus Zubiaga

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

Cargo.lock                                      |   5 
crates/edit_prediction/src/capture_example.rs   |   8 
crates/edit_prediction/src/udiff.rs             |  48 ++++
crates/edit_prediction_cli/Cargo.toml           |   1 
crates/edit_prediction_cli/src/load_project.rs  |  26 +
crates/edit_prediction_cli/src/main.rs          |  13 
crates/edit_prediction_cli/src/pull_examples.rs | 207 ++++++++++++++++++
7 files changed, 290 insertions(+), 18 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5292,6 +5292,7 @@ dependencies = [
  "dirs 4.0.0",
  "edit_prediction",
  "extension",
+ "flate2",
  "fs",
  "futures 0.3.31",
  "gpui",
@@ -6252,9 +6253,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
 
 [[package]]
 name = "flate2"
-version = "1.1.4"
+version = "1.1.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
+checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369"
 dependencies = [
  "crc32fast",
  "miniz_oxide",

crates/edit_prediction/src/capture_example.rs 🔗

@@ -13,6 +13,7 @@ use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::
 use text::{BufferSnapshot as TextBufferSnapshot, Point};
 
 pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
+pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
 
 pub fn capture_example(
     project: Entity<Project>,
@@ -232,10 +233,15 @@ fn generate_timestamp_name() -> String {
 }
 
 pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
+    let default_rate = if cx.is_staff() {
+        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
+    } else {
+        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
+    };
     let capture_rate = language::language_settings::all_language_settings(None, cx)
         .edit_predictions
         .example_capture_rate
-        .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
+        .unwrap_or(default_rate);
     cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
         && rand::random::<u16>() % 10_000 < capture_rate
 }

crates/edit_prediction/src/udiff.rs 🔗

@@ -214,6 +214,54 @@ pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
     Ok(result)
 }
 
+pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
+    if prefix.is_empty() {
+        return Cow::Borrowed(diff);
+    }
+
+    let prefix_with_slash = format!("{}/", prefix);
+    let mut needs_rewrite = false;
+
+    for line in diff.lines() {
+        match DiffLine::parse(line) {
+            DiffLine::OldPath { path } | DiffLine::NewPath { path } => {
+                if path.starts_with(&prefix_with_slash) {
+                    needs_rewrite = true;
+                    break;
+                }
+            }
+            _ => {}
+        }
+    }
+
+    if !needs_rewrite {
+        return Cow::Borrowed(diff);
+    }
+
+    let mut result = String::with_capacity(diff.len());
+    for line in diff.lines() {
+        match DiffLine::parse(line) {
+            DiffLine::OldPath { path } => {
+                let stripped = path
+                    .strip_prefix(&prefix_with_slash)
+                    .unwrap_or(path.as_ref());
+                result.push_str(&format!("--- a/{}\n", stripped));
+            }
+            DiffLine::NewPath { path } => {
+                let stripped = path
+                    .strip_prefix(&prefix_with_slash)
+                    .unwrap_or(path.as_ref());
+                result.push_str(&format!("+++ b/{}\n", stripped));
+            }
+            _ => {
+                result.push_str(line);
+                result.push('\n');
+            }
+        }
+    }
+
+    Cow::Owned(result)
+}
 /// Strip unnecessary git metadata lines from a diff, keeping only the lines
 /// needed for patch application: path headers (--- and +++), hunk headers (@@),
 /// and content lines (+, -, space).

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -57,6 +57,7 @@ wasmtime.workspace = true
 zeta_prompt.workspace = true
 rand.workspace = true
 similar = "2.7.0"
+flate2 = "1.1.8"
 
 # Wasmtime is included as a dependency in order to enable the same
 # features that are enabled in Zed.

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -5,7 +5,7 @@ use crate::{
     progress::{InfoStyle, Progress, Step, StepProgress},
 };
 use anyhow::{Context as _, Result};
-use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries};
+use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix};
 use edit_prediction::{
     EditPredictionStore, cursor_excerpt::editable_and_context_ranges_for_cursor_position, zeta2,
 };
@@ -111,8 +111,16 @@ async fn cursor_position(
     }
 
     let cursor_path_str = example.spec.cursor_path.to_string_lossy();
+    // Also try cursor path with first component stripped - old examples may have
+    // paths like "zed/crates/foo.rs" instead of "crates/foo.rs".
+    let cursor_path_without_prefix: PathBuf =
+        example.spec.cursor_path.components().skip(1).collect();
+    let cursor_path_without_prefix_str = cursor_path_without_prefix.to_string_lossy();
+
     // We try open_buffers first because the file might be new and not saved to disk
-    let cursor_buffer = if let Some(buffer) = open_buffers.get(&cursor_path_str) {
+    let cursor_buffer = if let Some(buffer) = open_buffers.get(cursor_path_str.as_ref()) {
+        buffer.clone()
+    } else if let Some(buffer) = open_buffers.get(cursor_path_without_prefix_str.as_ref()) {
         buffer.clone()
     } else {
         // Since the worktree scanner is disabled, manually refresh entries for the cursor path.
@@ -122,7 +130,9 @@ async fn cursor_position(
 
         let cursor_path = project
             .read_with(cx, |project, cx| {
-                project.find_project_path(&example.spec.cursor_path, cx)
+                project
+                    .find_project_path(&example.spec.cursor_path, cx)
+                    .or_else(|| project.find_project_path(&cursor_path_without_prefix, cx))
             })
             .with_context(|| {
                 format!(
@@ -282,9 +292,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
     }
     drop(repo_lock);
 
-    // Apply the uncommitted diff for this example.
     if !example.spec.uncommitted_diff.is_empty() {
         step_progress.set_substatus("applying diff");
+
+        // old examples had full paths in the uncommitted diff.
+        let uncommitted_diff =
+            strip_diff_path_prefix(&example.spec.uncommitted_diff, &repo_name.name);
+
         let mut apply_process = smol::process::Command::new("git")
             .current_dir(&worktree_path)
             .args(&["apply", "-"])
@@ -292,9 +306,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
             .spawn()?;
 
         let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
-        stdin
-            .write_all(example.spec.uncommitted_diff.as_bytes())
-            .await?;
+        stdin.write_all(uncommitted_diff.as_bytes()).await?;
         stdin.close().await?;
         drop(stdin);
 

crates/edit_prediction_cli/src/main.rs 🔗

@@ -21,7 +21,7 @@ use collections::HashSet;
 use edit_prediction::EditPredictionStore;
 use futures::channel::mpsc;
 use futures::{SinkExt as _, StreamExt as _};
-use gpui::{AppContext as _, Application};
+use gpui::{AppContext as _, Application, BackgroundExecutor};
 use zeta_prompt::ZetaVersion;
 
 use reqwest_client::ReqwestClient;
@@ -279,6 +279,7 @@ async fn load_examples(
     http_client: Arc<dyn http_client::HttpClient>,
     args: &EpArgs,
     output_path: Option<&PathBuf>,
+    background_executor: BackgroundExecutor,
 ) -> anyhow::Result<Vec<Example>> {
     let mut captured_after_timestamps = Vec::new();
     let mut file_inputs = Vec::new();
@@ -312,6 +313,7 @@ async fn load_examples(
             http_client,
             &captured_after_timestamps,
             max_rows_per_timestamp,
+            background_executor,
         )
         .await?;
         examples.append(&mut captured_examples);
@@ -465,8 +467,13 @@ fn main() {
 
         cx.spawn(async move |cx| {
             let result = async {
-                let mut examples =
-                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
+                let mut examples = load_examples(
+                    app_state.client.http_client(),
+                    &args,
+                    output.as_ref(),
+                    cx.background_executor().clone(),
+                )
+                .await?;
 
                 match &command {
                     Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -1,9 +1,13 @@
 use anyhow::{Context as _, Result};
+use flate2::read::GzDecoder;
+use gpui::BackgroundExecutor;
 use http_client::{AsyncBody, HttpClient, Method, Request};
 use indoc::indoc;
 use serde::Deserialize;
 use serde_json::{Value as JsonValue, json};
+use std::io::Read;
 use std::sync::Arc;
+use std::time::Duration;
 
 use crate::{
     example::Example,
@@ -12,9 +16,12 @@ use crate::{
 use edit_prediction::example_spec::ExampleSpec;
 
 const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
+const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
 const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
 
 const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+const POLL_INTERVAL: Duration = Duration::from_secs(2);
+const MAX_POLL_ATTEMPTS: usize = 120;
 
 /// Parse an input token of the form `captured-after:{timestamp}`.
 pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@@ -25,6 +32,7 @@ pub async fn fetch_captured_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
     max_rows_per_timestamp: usize,
+    background_executor: BackgroundExecutor,
 ) -> Result<Vec<Example>> {
     if after_timestamps.is_empty() {
         return Ok(Vec::new());
@@ -70,13 +78,60 @@ pub async fn fetch_captured_examples_after(
             }
         });
 
-        let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
-
-        step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
+        let response = run_sql_with_polling(
+            http_client.clone(),
+            &base_url,
+            &token,
+            &request,
+            &step_progress,
+            background_executor.clone(),
+        )
+        .await?;
+
+        let total_rows = response
+            .result_set_meta_data
+            .as_ref()
+            .and_then(|m| m.num_rows)
+            .unwrap_or(response.data.len() as i64);
+
+        let num_partitions = response
+            .result_set_meta_data
+            .as_ref()
+            .map(|m| m.partition_info.len())
+            .unwrap_or(1)
+            .max(1);
+
+        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
         step_progress.set_substatus("parsing");
 
         all_examples.extend(examples_from_response(&response)?);
 
+        if num_partitions > 1 {
+            let statement_handle = response
+                .statement_handle
+                .as_ref()
+                .context("response has multiple partitions but no statementHandle")?;
+
+            for partition in 1..num_partitions {
+                step_progress.set_substatus(format!(
+                    "fetching partition {}/{}",
+                    partition + 1,
+                    num_partitions
+                ));
+
+                let partition_response = fetch_partition(
+                    http_client.clone(),
+                    &base_url,
+                    &token,
+                    statement_handle,
+                    partition,
+                )
+                .await?;
+
+                all_examples.extend(examples_from_response(&partition_response)?);
+            }
+        }
+
         step_progress.set_substatus("done");
     }
 
@@ -84,6 +139,7 @@ pub async fn fetch_captured_examples_after(
 }
 
 #[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
 struct SnowflakeStatementResponse {
     #[serde(default)]
     data: Vec<Vec<JsonValue>>,
@@ -93,14 +149,25 @@ struct SnowflakeStatementResponse {
     code: Option<String>,
     #[serde(default)]
     message: Option<String>,
+    #[serde(default)]
+    statement_handle: Option<String>,
 }
 
 #[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
 struct SnowflakeResultSetMetaData {
     #[serde(default, rename = "rowType")]
     row_type: Vec<SnowflakeColumnMeta>,
+    #[serde(default)]
+    num_rows: Option<i64>,
+    #[serde(default)]
+    partition_info: Vec<SnowflakePartitionInfo>,
 }
 
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SnowflakePartitionInfo {}
+
 #[derive(Debug, Clone, Deserialize)]
 struct SnowflakeColumnMeta {
     #[serde(default)]
@@ -109,7 +176,7 @@ struct SnowflakeColumnMeta {
 
 fn examples_from_response(
     response: &SnowflakeStatementResponse,
-) -> Result<impl Iterator<Item = Example>> {
+) -> Result<impl Iterator<Item = Example> + '_> {
     if let Some(code) = &response.code {
         if code != SNOWFLAKE_SUCCESS_CODE {
             anyhow::bail!(
@@ -169,6 +236,136 @@ fn examples_from_response(
     Ok(iter)
 }
 
+async fn run_sql_with_polling(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    request: &serde_json::Value,
+    step_progress: &crate::progress::StepProgress,
+    background_executor: BackgroundExecutor,
+) -> Result<SnowflakeStatementResponse> {
+    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
+
+    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+        let statement_handle = response
+            .statement_handle
+            .as_ref()
+            .context("async query response missing statementHandle")?
+            .clone();
+
+        for attempt in 1..=MAX_POLL_ATTEMPTS {
+            step_progress.set_substatus(format!("polling ({attempt})"));
+
+            background_executor.timer(POLL_INTERVAL).await;
+
+            response =
+                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
+
+            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+                break;
+            }
+        }
+
+        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+            anyhow::bail!(
+                "query still running after {} poll attempts ({} seconds)",
+                MAX_POLL_ATTEMPTS,
+                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
+            );
+        }
+    }
+
+    Ok(response)
+}
+
+async fn fetch_partition(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    statement_handle: &str,
+    partition: usize,
+) -> Result<SnowflakeStatementResponse> {
+    let url = format!(
+        "{}/api/v2/statements/{}?partition={}",
+        base_url.trim_end_matches('/'),
+        statement_handle,
+        partition
+    );
+
+    let http_request = Request::builder()
+        .method(Method::GET)
+        .uri(url.as_str())
+        .header("Authorization", format!("Bearer {token}"))
+        .header(
+            "X-Snowflake-Authorization-Token-Type",
+            "PROGRAMMATIC_ACCESS_TOKEN",
+        )
+        .header("Accept", "application/json")
+        .header("Accept-Encoding", "gzip")
+        .body(AsyncBody::empty())?;
+
+    let response = http_client
+        .send(http_request)
+        .await
+        .context("failed to send partition request to Snowflake SQL API")?;
+
+    let status = response.status();
+    let content_encoding = response
+        .headers()
+        .get("content-encoding")
+        .and_then(|v| v.to_str().ok())
+        .map(|s| s.to_lowercase());
+
+    let body_bytes = {
+        use futures::AsyncReadExt as _;
+
+        let mut body = response.into_body();
+        let mut bytes = Vec::new();
+        body.read_to_end(&mut bytes)
+            .await
+            .context("failed to read Snowflake SQL API partition response body")?;
+        bytes
+    };
+
+    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
+        let mut decoder = GzDecoder::new(&body_bytes[..]);
+        let mut decompressed = Vec::new();
+        decoder
+            .read_to_end(&mut decompressed)
+            .context("failed to decompress gzip response")?;
+        decompressed
+    } else {
+        body_bytes
+    };
+
+    if !status.is_success() && status.as_u16() != 202 {
+        let body_text = String::from_utf8_lossy(&body_bytes);
+        anyhow::bail!(
+            "snowflake sql api partition request http {}: {}",
+            status.as_u16(),
+            body_text
+        );
+    }
+
+    if body_bytes.is_empty() {
+        anyhow::bail!(
+            "snowflake sql api partition {} returned empty response body (http {})",
+            partition,
+            status.as_u16()
+        );
+    }
+
+    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
+        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
+        format!(
+            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
+            partition,
+            status.as_u16(),
+            body_preview
+        )
+    })
+}
+
 async fn run_sql(
     http_client: Arc<dyn HttpClient>,
     base_url: &str,
@@ -209,7 +406,7 @@ async fn run_sql(
         bytes
     };
 
-    if !status.is_success() {
+    if !status.is_success() && status.as_u16() != 202 {
         let body_text = String::from_utf8_lossy(&body_bytes);
         anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
     }