Make all paths relative in captured EP examples (#46213)

Max Brunsfeld created

I also added debug events for Sweep edit predictions so that we can see
the API response for "no prediction" events when running evals.

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/capture_example.rs | 118 +++++++++++++++-----
crates/edit_prediction/src/sweep_ai.rs        |  79 +++++++++++--
2 files changed, 152 insertions(+), 45 deletions(-)

Detailed changes

crates/edit_prediction/src/capture_example.rs 🔗

@@ -27,7 +27,8 @@ pub fn capture_example(
     let repository = project.read(cx).active_repository(cx)?;
     let repository_snapshot = repository.read(cx).snapshot();
     let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
-    let cursor_path = worktree.read(cx).root_name().join(file.path());
+    let root_name = worktree.read(cx).root_name_str().to_owned();
+    let cursor_path: Arc<Path> = file.path().as_std_path().into();
     if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
         return None;
     }
@@ -45,14 +46,9 @@ pub fn capture_example(
             collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 
         events.retain(|stored_event| {
-            match stored_event.event.as_ref() {
-                zeta_prompt::Event::BufferChange { path, .. } => {
-                    if !snapshots_by_path.contains_key(path) {
-                        return false;
-                    }
-                }
-            }
-            true
+            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
+            let relative_path = strip_root_name(path, &root_name);
+            snapshots_by_path.contains_key(relative_path)
         });
 
         let line_comment_prefix = snapshot
@@ -71,7 +67,7 @@ pub fn capture_example(
 
         let mut edit_history = String::new();
         for stored_event in &events {
-            zeta_prompt::write_event(&mut edit_history, &stored_event.event);
+            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
             if !edit_history.ends_with('\n') {
                 edit_history.push('\n');
             }
@@ -84,7 +80,7 @@ pub fn capture_example(
             tags: Vec::new(),
             reasoning: None,
             uncommitted_diff,
-            cursor_path: cursor_path.as_std_path().into(),
+            cursor_path,
             cursor_position: String::new(),
             edit_history,
             expected_patches: Vec::new(),
@@ -94,6 +90,37 @@ pub fn capture_example(
     }))
 }
 
+fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
+    path.strip_prefix(root_name).unwrap_or(path)
+}
+
+fn write_event_with_relative_paths(
+    output: &mut String,
+    event: &zeta_prompt::Event,
+    root_name: &str,
+) {
+    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
+        for component in strip_root_name(path, root_name).components() {
+            output.push('/');
+            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
+        }
+    }
+
+    let zeta_prompt::Event::BufferChange {
+        path,
+        old_path,
+        diff,
+        ..
+    } = event;
+
+    output.push_str("--- a");
+    write_relative_path(output, old_path.as_ref(), root_name);
+    output.push_str("\n+++ b");
+    write_relative_path(output, path.as_ref(), root_name);
+    output.push('\n');
+    output.push_str(diff);
+}
+
 fn compute_cursor_excerpt(
     snapshot: &language::BufferSnapshot,
     cursor_anchor: language::Anchor,
@@ -118,24 +145,16 @@ async fn collect_snapshots(
     cx: &mut gpui::AsyncApp,
 ) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
     let mut snapshots_by_path = HashMap::default();
-    let root_name = project.read_with(cx, |project, cx| {
-        project
-            .worktree_for_id(worktree_id, cx)
-            .unwrap()
-            .read(cx)
-            .root_name()
-            .to_owned()
-    })?;
     for stored_event in events {
         let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
-        if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
+        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
             let project_path = project
                 .find_project_path(path, cx)
                 .filter(|path| path.worktree_id == worktree_id)?;
-            let full_path = root_name.join(&project_path.path).as_std_path().into();
-            Some((project_path, full_path))
+            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
+            Some((project_path, relative_path))
         })? {
-            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
+            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
                 let buffer = project
                     .update(cx, |project, cx| {
                         project.open_buffer(project_path.clone(), cx)
@@ -158,11 +177,11 @@ fn compute_uncommitted_diff(
     snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
 ) -> String {
     let mut uncommitted_diff = String::new();
-    for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
+    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
         if let Some(head_text) = &diff_snapshot.base_text_string() {
             let file_diff = language::unified_diff(head_text, &before_text.text());
             if !file_diff.is_empty() {
-                let path_str = full_path.to_string_lossy();
+                let path_str = relative_path.to_string_lossy();
                 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
                 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
                 uncommitted_diff.push_str(&file_diff);
@@ -257,6 +276,15 @@ mod tests {
         )
         .await;
 
+        // Create an external file outside the main project
+        fs.insert_tree(
+            "/external",
+            json!({
+                "external.rs": "fn external() {}\n",
+            }),
+        )
+        .await;
+
         fs.set_head_for_repo(
             Path::new("/project/.git"),
             &[("src/main.rs", committed_contents.to_string())],
@@ -312,9 +340,39 @@ mod tests {
         });
         cx.run_until_parked();
 
+        // Open and edit an external file (outside the main project's worktree)
+        let external_buffer = project
+            .update(cx, |project, cx| {
+                project.open_local_buffer("/external/external.rs", cx)
+            })
+            .await
+            .unwrap();
+        ep_store.update(cx, |ep_store, cx| {
+            ep_store.register_buffer(&external_buffer, &project, cx)
+        });
+        cx.run_until_parked();
+        external_buffer.update(cx, |buffer, cx| {
+            let point = Point::new(0, 0);
+            buffer.edit([(point..point, "// external edit\n")], None, cx);
+        });
+        cx.run_until_parked();
+
+        // Verify the external edit was recorded in events
         let events = ep_store.update(cx, |store, cx| {
             store.edit_history_for_project_with_pause_split_last_event(&project, cx)
         });
+        assert!(
+            matches!(
+                events
+                    .last()
+                    .unwrap()
+                    .event
+                    .as_ref(),
+                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
+            ),
+            "external file edit should be in events"
+        );
+
         let mut example = cx
             .update(|cx| {
                 capture_example(project.clone(), buffer.clone(), Anchor::MIN, events, cx).unwrap()
@@ -332,8 +390,8 @@ mod tests {
                 tags: Vec::new(),
                 reasoning: None,
                 uncommitted_diff: indoc! {"
-                    --- a/project/src/main.rs
-                    +++ b/project/src/main.rs
+                    --- a/src/main.rs
+                    +++ b/src/main.rs
                     @@ -1,4 +1,5 @@
                      fn main() {
                     +    // comment 1
@@ -349,7 +407,7 @@ mod tests {
                      }
                 "}
                 .to_string(),
-                cursor_path: Path::new("project/src/main.rs").into(),
+                cursor_path: Path::new("src/main.rs").into(),
                 cursor_position: indoc! {"
                     fn main() {
                     ^[CURSOR_POSITION]
@@ -370,8 +428,8 @@ mod tests {
                 "}
                 .to_string(),
                 edit_history: indoc! {"
-                    --- a/project/src/main.rs
-                    +++ b/project/src/main.rs
+                    --- a/src/main.rs
+                    +++ b/src/main.rs
                     @@ -2,8 +2,10 @@
                          // comment 1
                          one();

crates/edit_prediction/src/sweep_ai.rs 🔗

@@ -1,11 +1,17 @@
-use anyhow::Result;
+use crate::{
+    CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
+    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
+    EditPredictionStore, prediction::EditPredictionResult,
+};
+use anyhow::{Result, bail};
 use client::Client;
-use futures::AsyncReadExt as _;
+use edit_prediction_types::SuggestionDisplayType;
+use futures::{AsyncReadExt as _, channel::mpsc};
 use gpui::{
     App, AppContext as _, Entity, Global, SharedString, Task,
     http_client::{self, AsyncBody, Method},
 };
-use language::{Anchor, BufferSnapshot, Point, ToOffset as _};
+use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
 use language_model::{ApiKeyState, EnvVar, env_var};
 use lsp::DiagnosticSeverity;
 use serde::{Deserialize, Serialize};
@@ -17,12 +23,6 @@ use std::{
     time::Instant,
 };
 
-use crate::{
-    CurrentEditPrediction, EditPrediction, EditPredictionId, EditPredictionModelInput,
-    EditPredictionStore, prediction::EditPredictionResult,
-};
-use edit_prediction_types::SuggestionDisplayType;
-
 const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
 
@@ -48,6 +48,10 @@ impl SweepAi {
         self.api_token.update(cx, |key_state, cx| {
             _ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
         });
+
+        let buffer = inputs.buffer.clone();
+        let debug_tx = inputs.debug_tx.clone();
+
         let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
             return Task::ready(Ok(None));
         };
@@ -195,12 +199,19 @@ impl SweepAi {
                 events: inputs.events,
                 related_files: inputs.related_files.clone(),
                 cursor_path: full_path.clone(),
-                cursor_excerpt: request_body.file_contents.into(),
+                cursor_excerpt: request_body.file_contents.clone().into(),
                 // we actually don't know
                 editable_range_in_excerpt: 0..inputs.snapshot.len(),
                 cursor_offset_in_excerpt: request_body.cursor_position,
             };
 
+            send_started_event(
+                &debug_tx,
+                &buffer,
+                inputs.position,
+                serde_json::to_string(&request_body).unwrap_or_default(),
+            );
+
             let request = http_client::Request::builder()
                 .uri(SWEEP_API_URL)
                 .header("Content-Type", "application/json")
@@ -212,19 +223,23 @@ impl SweepAi {
 
             let mut response = http_client.send(request).await?;
 
-            let mut body: Vec<u8> = Vec::new();
-            response.body_mut().read_to_end(&mut body).await?;
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
 
             let response_received_at = Instant::now();
             if !response.status().is_success() {
-                anyhow::bail!(
+                let message = format!(
                     "Request failed with status: {:?}\nBody: {}",
                     response.status(),
-                    String::from_utf8_lossy(&body),
+                    body,
                 );
+                send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
+                bail!(message);
             };
 
-            let response: AutocompleteResponse = serde_json::from_slice(&body)?;
+            let response: AutocompleteResponse = serde_json::from_str(&body)?;
+
+            send_finished_event(&debug_tx, &buffer, inputs.position, body);
 
             let old_text = inputs
                 .snapshot
@@ -275,6 +290,40 @@ impl SweepAi {
     }
 }
 
+fn send_started_event(
+    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
+    buffer: &Entity<Buffer>,
+    position: Anchor,
+    prompt: String,
+) {
+    if let Some(debug_tx) = debug_tx {
+        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
+            EditPredictionStartedDebugEvent {
+                buffer: buffer.downgrade(),
+                position,
+                prompt: Some(prompt),
+            },
+        ));
+    }
+}
+
+fn send_finished_event(
+    debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
+    buffer: &Entity<Buffer>,
+    position: Anchor,
+    model_output: String,
+) {
+    if let Some(debug_tx) = debug_tx {
+        _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
+            EditPredictionFinishedDebugEvent {
+                buffer: buffer.downgrade(),
+                position,
+                model_output: Some(model_output),
+            },
+        ));
+    }
+}
+
 pub const SWEEP_CREDENTIALS_URL: SharedString =
     SharedString::new_static("https://autocomplete.sweep.dev");
 pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";