@@ -27,7 +27,8 @@ pub fn capture_example(
let repository = project.read(cx).active_repository(cx)?;
let repository_snapshot = repository.read(cx).snapshot();
let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
- let cursor_path = worktree.read(cx).root_name().join(file.path());
+ let root_name = worktree.read(cx).root_name_str().to_owned();
+ let cursor_path: Arc<Path> = file.path().as_std_path().into();
if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
return None;
}
@@ -45,14 +46,9 @@ pub fn capture_example(
collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
events.retain(|stored_event| {
- match stored_event.event.as_ref() {
- zeta_prompt::Event::BufferChange { path, .. } => {
- if !snapshots_by_path.contains_key(path) {
- return false;
- }
- }
- }
- true
+ let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
+ let relative_path = strip_root_name(path, &root_name);
+ snapshots_by_path.contains_key(relative_path)
});
let line_comment_prefix = snapshot
@@ -71,7 +67,7 @@ pub fn capture_example(
let mut edit_history = String::new();
for stored_event in &events {
- zeta_prompt::write_event(&mut edit_history, &stored_event.event);
+ write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
if !edit_history.ends_with('\n') {
edit_history.push('\n');
}
@@ -84,7 +80,7 @@ pub fn capture_example(
tags: Vec::new(),
reasoning: None,
uncommitted_diff,
- cursor_path: cursor_path.as_std_path().into(),
+ cursor_path,
cursor_position: String::new(),
edit_history,
expected_patches: Vec::new(),
@@ -94,6 +90,37 @@ pub fn capture_example(
}))
}
+fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
+ path.strip_prefix(root_name).unwrap_or(path)
+}
+
+fn write_event_with_relative_paths(
+ output: &mut String,
+ event: &zeta_prompt::Event,
+ root_name: &str,
+) {
+ fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
+ for component in strip_root_name(path, root_name).components() {
+ output.push('/');
+ write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
+ }
+ }
+
+ let zeta_prompt::Event::BufferChange {
+ path,
+ old_path,
+ diff,
+ ..
+ } = event;
+
+ output.push_str("--- a");
+ write_relative_path(output, old_path.as_ref(), root_name);
+ output.push_str("\n+++ b");
+ write_relative_path(output, path.as_ref(), root_name);
+ output.push('\n');
+ output.push_str(diff);
+}
+
fn compute_cursor_excerpt(
snapshot: &language::BufferSnapshot,
cursor_anchor: language::Anchor,
@@ -118,24 +145,16 @@ async fn collect_snapshots(
cx: &mut gpui::AsyncApp,
) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
let mut snapshots_by_path = HashMap::default();
- let root_name = project.read_with(cx, |project, cx| {
- project
- .worktree_for_id(worktree_id, cx)
- .unwrap()
- .read(cx)
- .root_name()
- .to_owned()
- })?;
for stored_event in events {
let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
- if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
+ if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
let project_path = project
.find_project_path(path, cx)
.filter(|path| path.worktree_id == worktree_id)?;
- let full_path = root_name.join(&project_path.path).as_std_path().into();
- Some((project_path, full_path))
+ let relative_path: Arc<Path> = project_path.path.as_std_path().into();
+ Some((project_path, relative_path))
})? {
- if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
+ if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
@@ -158,11 +177,11 @@ fn compute_uncommitted_diff(
snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
) -> String {
let mut uncommitted_diff = String::new();
- for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
+ for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
if let Some(head_text) = &diff_snapshot.base_text_string() {
let file_diff = language::unified_diff(head_text, &before_text.text());
if !file_diff.is_empty() {
- let path_str = full_path.to_string_lossy();
+ let path_str = relative_path.to_string_lossy();
writeln!(uncommitted_diff, "--- a/{path_str}").ok();
writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
uncommitted_diff.push_str(&file_diff);
@@ -257,6 +276,15 @@ mod tests {
)
.await;
+ // Create an external file outside the main project
+ fs.insert_tree(
+ "/external",
+ json!({
+ "external.rs": "fn external() {}\n",
+ }),
+ )
+ .await;
+
fs.set_head_for_repo(
Path::new("/project/.git"),
&[("src/main.rs", committed_contents.to_string())],
@@ -312,9 +340,39 @@ mod tests {
});
cx.run_until_parked();
+ // Open and edit an external file (outside the main project's worktree)
+ let external_buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer("/external/external.rs", cx)
+ })
+ .await
+ .unwrap();
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&external_buffer, &project, cx)
+ });
+ cx.run_until_parked();
+ external_buffer.update(cx, |buffer, cx| {
+ let point = Point::new(0, 0);
+ buffer.edit([(point..point, "// external edit\n")], None, cx);
+ });
+ cx.run_until_parked();
+
+ // Verify the external edit was recorded in events
let events = ep_store.update(cx, |store, cx| {
store.edit_history_for_project_with_pause_split_last_event(&project, cx)
});
+ assert!(
+ matches!(
+ events
+ .last()
+ .unwrap()
+ .event
+ .as_ref(),
+ zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
+ ),
+ "external file edit should be in events"
+ );
+
let mut example = cx
.update(|cx| {
capture_example(project.clone(), buffer.clone(), Anchor::MIN, events, cx).unwrap()
@@ -332,8 +390,8 @@ mod tests {
tags: Vec::new(),
reasoning: None,
uncommitted_diff: indoc! {"
- --- a/project/src/main.rs
- +++ b/project/src/main.rs
+ --- a/src/main.rs
+ +++ b/src/main.rs
@@ -1,4 +1,5 @@
fn main() {
+ // comment 1
@@ -349,7 +407,7 @@ mod tests {
}
"}
.to_string(),
- cursor_path: Path::new("project/src/main.rs").into(),
+ cursor_path: Path::new("src/main.rs").into(),
cursor_position: indoc! {"
fn main() {
^[CURSOR_POSITION]
@@ -370,8 +428,8 @@ mod tests {
"}
.to_string(),
edit_history: indoc! {"
- --- a/project/src/main.rs
- +++ b/project/src/main.rs
+ --- a/src/main.rs
+ +++ b/src/main.rs
@@ -2,8 +2,10 @@
// comment 1
one();
@@ -1,11 +1,17 @@
-use anyhow::Result;
+use crate::{
+ CurrentEditPrediction, DebugEvent, EditPrediction, EditPredictionFinishedDebugEvent,
+ EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
+ EditPredictionStore, prediction::EditPredictionResult,
+};
+use anyhow::{Result, bail};
use client::Client;
-use futures::AsyncReadExt as _;
+use edit_prediction_types::SuggestionDisplayType;
+use futures::{AsyncReadExt as _, channel::mpsc};
use gpui::{
App, AppContext as _, Entity, Global, SharedString, Task,
http_client::{self, AsyncBody, Method},
};
-use language::{Anchor, BufferSnapshot, Point, ToOffset as _};
+use language::{Anchor, Buffer, BufferSnapshot, Point, ToOffset as _};
use language_model::{ApiKeyState, EnvVar, env_var};
use lsp::DiagnosticSeverity;
use serde::{Deserialize, Serialize};
@@ -17,12 +23,6 @@ use std::{
time::Instant,
};
-use crate::{
- CurrentEditPrediction, EditPrediction, EditPredictionId, EditPredictionModelInput,
- EditPredictionStore, prediction::EditPredictionResult,
-};
-use edit_prediction_types::SuggestionDisplayType;
-
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
const SWEEP_METRICS_URL: &str = "https://backend.app.sweep.dev/backend/track_autocomplete_metrics";
@@ -48,6 +48,10 @@ impl SweepAi {
self.api_token.update(cx, |key_state, cx| {
_ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
});
+
+ let buffer = inputs.buffer.clone();
+ let debug_tx = inputs.debug_tx.clone();
+
let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
return Task::ready(Ok(None));
};
@@ -195,12 +199,19 @@ impl SweepAi {
events: inputs.events,
related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
- cursor_excerpt: request_body.file_contents.into(),
+ cursor_excerpt: request_body.file_contents.clone().into(),
// we actually don't know
editable_range_in_excerpt: 0..inputs.snapshot.len(),
cursor_offset_in_excerpt: request_body.cursor_position,
};
+ send_started_event(
+ &debug_tx,
+ &buffer,
+ inputs.position,
+ serde_json::to_string(&request_body).unwrap_or_default(),
+ );
+
let request = http_client::Request::builder()
.uri(SWEEP_API_URL)
.header("Content-Type", "application/json")
@@ -212,19 +223,23 @@ impl SweepAi {
let mut response = http_client.send(request).await?;
- let mut body: Vec<u8> = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
let response_received_at = Instant::now();
if !response.status().is_success() {
- anyhow::bail!(
+ let message = format!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
- String::from_utf8_lossy(&body),
+ body,
);
+ send_finished_event(&debug_tx, &buffer, inputs.position, message.clone());
+ bail!(message);
};
- let response: AutocompleteResponse = serde_json::from_slice(&body)?;
+ let response: AutocompleteResponse = serde_json::from_str(&body)?;
+
+ send_finished_event(&debug_tx, &buffer, inputs.position, body);
let old_text = inputs
.snapshot
@@ -275,6 +290,40 @@ impl SweepAi {
}
}
+fn send_started_event(
+ debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
+ buffer: &Entity<Buffer>,
+ position: Anchor,
+ prompt: String,
+) {
+ if let Some(debug_tx) = debug_tx {
+ _ = debug_tx.unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ position,
+ prompt: Some(prompt),
+ },
+ ));
+ }
+}
+
+fn send_finished_event(
+ debug_tx: &Option<mpsc::UnboundedSender<DebugEvent>>,
+ buffer: &Entity<Buffer>,
+ position: Anchor,
+ model_output: String,
+) {
+ if let Some(debug_tx) = debug_tx {
+ _ = debug_tx.unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ position,
+ model_output: Some(model_output),
+ },
+ ));
+ }
+}
+
pub const SWEEP_CREDENTIALS_URL: SharedString =
SharedString::new_static("https://autocomplete.sweep.dev");
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";