Detailed changes
@@ -5292,6 +5292,7 @@ dependencies = [
"dirs 4.0.0",
"edit_prediction",
"extension",
+ "flate2",
"fs",
"futures 0.3.31",
"gpui",
@@ -6252,9 +6253,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
-version = "1.1.4"
+version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
+checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369"
dependencies = [
"crc32fast",
"miniz_oxide",
@@ -13,6 +13,7 @@ use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::
use text::{BufferSnapshot as TextBufferSnapshot, Point};
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
+pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
pub fn capture_example(
project: Entity<Project>,
@@ -232,10 +233,15 @@ fn generate_timestamp_name() -> String {
}
pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
+ let default_rate = if cx.is_staff() {
+ DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
+ } else {
+ DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
+ };
let capture_rate = language::language_settings::all_language_settings(None, cx)
.edit_predictions
.example_capture_rate
- .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
+ .unwrap_or(default_rate);
cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
&& rand::random::<u16>() % 10_000 < capture_rate
}
@@ -214,6 +214,54 @@ pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
Ok(result)
}
+pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
+ if prefix.is_empty() {
+ return Cow::Borrowed(diff);
+ }
+
+ let prefix_with_slash = format!("{}/", prefix);
+ let mut needs_rewrite = false;
+
+ for line in diff.lines() {
+ match DiffLine::parse(line) {
+ DiffLine::OldPath { path } | DiffLine::NewPath { path } => {
+ if path.starts_with(&prefix_with_slash) {
+ needs_rewrite = true;
+ break;
+ }
+ }
+ _ => {}
+ }
+ }
+
+ if !needs_rewrite {
+ return Cow::Borrowed(diff);
+ }
+
+ let mut result = String::with_capacity(diff.len());
+ for line in diff.lines() {
+ match DiffLine::parse(line) {
+ DiffLine::OldPath { path } => {
+ let stripped = path
+ .strip_prefix(&prefix_with_slash)
+ .unwrap_or(path.as_ref());
+ result.push_str(&format!("--- a/{}\n", stripped));
+ }
+ DiffLine::NewPath { path } => {
+ let stripped = path
+ .strip_prefix(&prefix_with_slash)
+ .unwrap_or(path.as_ref());
+ result.push_str(&format!("+++ b/{}\n", stripped));
+ }
+ _ => {
+ result.push_str(line);
+ result.push('\n');
+ }
+ }
+ }
+
+ Cow::Owned(result)
+}
/// Strip unnecessary git metadata lines from a diff, keeping only the lines
/// needed for patch application: path headers (--- and +++), hunk headers (@@),
/// and content lines (+, -, space).
@@ -57,6 +57,7 @@ wasmtime.workspace = true
zeta_prompt.workspace = true
rand.workspace = true
similar = "2.7.0"
+flate2 = "1.1.8"
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.
@@ -5,7 +5,7 @@ use crate::{
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::{Context as _, Result};
-use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries};
+use edit_prediction::udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix};
use edit_prediction::{
EditPredictionStore, cursor_excerpt::editable_and_context_ranges_for_cursor_position, zeta2,
};
@@ -111,8 +111,16 @@ async fn cursor_position(
}
let cursor_path_str = example.spec.cursor_path.to_string_lossy();
+ // Also try the cursor path with its first component stripped — old examples
+ // may have paths like "zed/crates/foo.rs" instead of "crates/foo.rs".
+ let cursor_path_without_prefix: PathBuf =
+ example.spec.cursor_path.components().skip(1).collect();
+ let cursor_path_without_prefix_str = cursor_path_without_prefix.to_string_lossy();
+
// We try open_buffers first because the file might be new and not saved to disk
- let cursor_buffer = if let Some(buffer) = open_buffers.get(&cursor_path_str) {
+ let cursor_buffer = if let Some(buffer) = open_buffers.get(cursor_path_str.as_ref()) {
+ buffer.clone()
+ } else if let Some(buffer) = open_buffers.get(cursor_path_without_prefix_str.as_ref()) {
buffer.clone()
} else {
// Since the worktree scanner is disabled, manually refresh entries for the cursor path.
@@ -122,7 +130,9 @@ async fn cursor_position(
let cursor_path = project
.read_with(cx, |project, cx| {
- project.find_project_path(&example.spec.cursor_path, cx)
+ project
+ .find_project_path(&example.spec.cursor_path, cx)
+ .or_else(|| project.find_project_path(&cursor_path_without_prefix, cx))
})
.with_context(|| {
format!(
@@ -282,9 +292,13 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
}
drop(repo_lock);
- // Apply the uncommitted diff for this example.
if !example.spec.uncommitted_diff.is_empty() {
step_progress.set_substatus("applying diff");
+
+ // Old examples had the repository name prefixed to paths in the uncommitted diff.
+ let uncommitted_diff =
+ strip_diff_path_prefix(&example.spec.uncommitted_diff, &repo_name.name);
+
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
@@ -292,9 +306,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
.spawn()?;
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
- stdin
- .write_all(example.spec.uncommitted_diff.as_bytes())
- .await?;
+ stdin.write_all(uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
@@ -21,7 +21,7 @@ use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
-use gpui::{AppContext as _, Application};
+use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;
use reqwest_client::ReqwestClient;
@@ -279,6 +279,7 @@ async fn load_examples(
http_client: Arc<dyn http_client::HttpClient>,
args: &EpArgs,
output_path: Option<&PathBuf>,
+ background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
let mut captured_after_timestamps = Vec::new();
let mut file_inputs = Vec::new();
@@ -312,6 +313,7 @@ async fn load_examples(
http_client,
&captured_after_timestamps,
max_rows_per_timestamp,
+ background_executor,
)
.await?;
examples.append(&mut captured_examples);
@@ -465,8 +467,13 @@ fn main() {
cx.spawn(async move |cx| {
let result = async {
- let mut examples =
- load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
+ let mut examples = load_examples(
+ app_state.client.http_client(),
+ &args,
+ output.as_ref(),
+ cx.background_executor().clone(),
+ )
+ .await?;
match &command {
Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
@@ -1,9 +1,13 @@
use anyhow::{Context as _, Result};
+use flate2::read::GzDecoder;
+use gpui::BackgroundExecutor;
use http_client::{AsyncBody, HttpClient, Method, Request};
use indoc::indoc;
use serde::Deserialize;
use serde_json::{Value as JsonValue, json};
+use std::io::Read;
use std::sync::Arc;
+use std::time::Duration;
use crate::{
example::Example,
@@ -12,9 +16,12 @@ use crate::{
use edit_prediction::example_spec::ExampleSpec;
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
+const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+const POLL_INTERVAL: Duration = Duration::from_secs(2);
+const MAX_POLL_ATTEMPTS: usize = 120;
/// Parse an input token of the form `captured-after:{timestamp}`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@@ -25,6 +32,7 @@ pub async fn fetch_captured_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
max_rows_per_timestamp: usize,
+ background_executor: BackgroundExecutor,
) -> Result<Vec<Example>> {
if after_timestamps.is_empty() {
return Ok(Vec::new());
@@ -70,13 +78,60 @@ pub async fn fetch_captured_examples_after(
}
});
- let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
-
- step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
+ let response = run_sql_with_polling(
+ http_client.clone(),
+ &base_url,
+ &token,
+ &request,
+ &step_progress,
+ background_executor.clone(),
+ )
+ .await?;
+
+ let total_rows = response
+ .result_set_meta_data
+ .as_ref()
+ .and_then(|m| m.num_rows)
+ .unwrap_or(response.data.len() as i64);
+
+ let num_partitions = response
+ .result_set_meta_data
+ .as_ref()
+ .map(|m| m.partition_info.len())
+ .unwrap_or(1)
+ .max(1);
+
+ step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
step_progress.set_substatus("parsing");
all_examples.extend(examples_from_response(&response)?);
+ if num_partitions > 1 {
+ let statement_handle = response
+ .statement_handle
+ .as_ref()
+ .context("response has multiple partitions but no statementHandle")?;
+
+ for partition in 1..num_partitions {
+ step_progress.set_substatus(format!(
+ "fetching partition {}/{}",
+ partition + 1,
+ num_partitions
+ ));
+
+ let partition_response = fetch_partition(
+ http_client.clone(),
+ &base_url,
+ &token,
+ statement_handle,
+ partition,
+ )
+ .await?;
+
+ all_examples.extend(examples_from_response(&partition_response)?);
+ }
+ }
+
step_progress.set_substatus("done");
}
@@ -84,6 +139,7 @@ pub async fn fetch_captured_examples_after(
}
#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
#[serde(default)]
data: Vec<Vec<JsonValue>>,
@@ -93,14 +149,25 @@ struct SnowflakeStatementResponse {
code: Option<String>,
#[serde(default)]
message: Option<String>,
+ #[serde(default)]
+ statement_handle: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
struct SnowflakeResultSetMetaData {
#[serde(default, rename = "rowType")]
row_type: Vec<SnowflakeColumnMeta>,
+ #[serde(default)]
+ num_rows: Option<i64>,
+ #[serde(default)]
+ partition_info: Vec<SnowflakePartitionInfo>,
}
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SnowflakePartitionInfo {}
+
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
#[serde(default)]
@@ -109,7 +176,7 @@ struct SnowflakeColumnMeta {
fn examples_from_response(
response: &SnowflakeStatementResponse,
-) -> Result<impl Iterator<Item = Example>> {
+) -> Result<impl Iterator<Item = Example> + '_> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
@@ -169,6 +236,136 @@ fn examples_from_response(
Ok(iter)
}
+async fn run_sql_with_polling(
+ http_client: Arc<dyn HttpClient>,
+ base_url: &str,
+ token: &str,
+ request: &serde_json::Value,
+ step_progress: &crate::progress::StepProgress,
+ background_executor: BackgroundExecutor,
+) -> Result<SnowflakeStatementResponse> {
+ let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
+
+ if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+ let statement_handle = response
+ .statement_handle
+ .as_ref()
+ .context("async query response missing statementHandle")?
+ .clone();
+
+ for attempt in 1..=MAX_POLL_ATTEMPTS {
+ step_progress.set_substatus(format!("polling ({attempt})"));
+
+ background_executor.timer(POLL_INTERVAL).await;
+
+ response =
+ fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
+
+ if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+ break;
+ }
+ }
+
+ if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
+ anyhow::bail!(
+ "query still running after {} poll attempts ({} seconds)",
+ MAX_POLL_ATTEMPTS,
+ MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
+ );
+ }
+ }
+
+ Ok(response)
+}
+
+async fn fetch_partition(
+ http_client: Arc<dyn HttpClient>,
+ base_url: &str,
+ token: &str,
+ statement_handle: &str,
+ partition: usize,
+) -> Result<SnowflakeStatementResponse> {
+ let url = format!(
+ "{}/api/v2/statements/{}?partition={}",
+ base_url.trim_end_matches('/'),
+ statement_handle,
+ partition
+ );
+
+ let http_request = Request::builder()
+ .method(Method::GET)
+ .uri(url.as_str())
+ .header("Authorization", format!("Bearer {token}"))
+ .header(
+ "X-Snowflake-Authorization-Token-Type",
+ "PROGRAMMATIC_ACCESS_TOKEN",
+ )
+ .header("Accept", "application/json")
+ .header("Accept-Encoding", "gzip")
+ .body(AsyncBody::empty())?;
+
+ let response = http_client
+ .send(http_request)
+ .await
+ .context("failed to send partition request to Snowflake SQL API")?;
+
+ let status = response.status();
+ let content_encoding = response
+ .headers()
+ .get("content-encoding")
+ .and_then(|v| v.to_str().ok())
+ .map(|s| s.to_lowercase());
+
+ let body_bytes = {
+ use futures::AsyncReadExt as _;
+
+ let mut body = response.into_body();
+ let mut bytes = Vec::new();
+ body.read_to_end(&mut bytes)
+ .await
+ .context("failed to read Snowflake SQL API partition response body")?;
+ bytes
+ };
+
+ let body_bytes = if content_encoding.as_deref() == Some("gzip") {
+ let mut decoder = GzDecoder::new(&body_bytes[..]);
+ let mut decompressed = Vec::new();
+ decoder
+ .read_to_end(&mut decompressed)
+ .context("failed to decompress gzip response")?;
+ decompressed
+ } else {
+ body_bytes
+ };
+
+ if !status.is_success() && status.as_u16() != 202 {
+ let body_text = String::from_utf8_lossy(&body_bytes);
+ anyhow::bail!(
+ "snowflake sql api partition request http {}: {}",
+ status.as_u16(),
+ body_text
+ );
+ }
+
+ if body_bytes.is_empty() {
+ anyhow::bail!(
+ "snowflake sql api partition {} returned empty response body (http {})",
+ partition,
+ status.as_u16()
+ );
+ }
+
+ serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
+ let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
+ format!(
+ "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
+ partition,
+ status.as_u16(),
+ body_preview
+ )
+ })
+}
+
async fn run_sql(
http_client: Arc<dyn HttpClient>,
base_url: &str,
@@ -209,7 +406,7 @@ async fn run_sql(
bytes
};
- if !status.is_success() {
+ if !status.is_success() && status.as_u16() != 202 {
let body_text = String::from_utf8_lossy(&body_bytes);
anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
}