Detailed changes
@@ -5212,6 +5212,7 @@ dependencies = [
"anyhow",
"arrayvec",
"brotli",
+ "buffer_diff",
"client",
"clock",
"cloud_api_types",
@@ -5249,7 +5250,9 @@ dependencies = [
"strum 0.27.2",
"telemetry",
"telemetry_events",
+ "text",
"thiserror 2.0.17",
+ "time",
"ui",
"util",
"uuid",
@@ -5354,8 +5357,10 @@ dependencies = [
"anyhow",
"buffer_diff",
"client",
+ "clock",
"cloud_llm_client",
"codestral",
+ "collections",
"command_palette_hooks",
"copilot",
"edit_prediction",
@@ -5364,18 +5369,20 @@ dependencies = [
"feature_flags",
"fs",
"futures 0.3.31",
- "git",
"gpui",
"indoc",
"language",
- "log",
+ "language_model",
"lsp",
"markdown",
"menu",
"multi_buffer",
"paths",
+ "pretty_assertions",
"project",
"regex",
+ "release_channel",
+ "semver",
"serde_json",
"settings",
"supermaven",
@@ -5388,6 +5395,7 @@ dependencies = [
"workspace",
"zed_actions",
"zeta_prompt",
+ "zlog",
]
[[package]]
@@ -314,6 +314,12 @@ impl BufferDiffSnapshot {
self.inner.hunks.is_empty()
}
+ pub fn base_text_string(&self) -> Option<String> {
+ self.inner
+ .base_text_exists
+ .then(|| self.inner.base_text.text())
+ }
+
pub fn secondary_diff(&self) -> Option<&BufferDiffSnapshot> {
self.secondary_diff.as_deref()
}
@@ -19,6 +19,7 @@ ai_onboarding.workspace = true
anyhow.workspace = true
arrayvec.workspace = true
brotli.workspace = true
+buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
@@ -52,7 +53,9 @@ settings.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
+text.workspace = true
thiserror.workspace = true
+time.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
@@ -0,0 +1,375 @@
+use crate::{
+ EditPredictionStore, StoredEvent,
+ cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
+};
+use anyhow::Result;
+use buffer_diff::BufferDiffSnapshot;
+use collections::HashMap;
+use gpui::{App, Entity, Task};
+use language::{Buffer, ToPoint as _};
+use project::Project;
+use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
+use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
+
+pub fn capture_example(
+ project: Entity<Project>,
+ buffer: Entity<Buffer>,
+ cursor_anchor: language::Anchor,
+ last_event_is_expected_patch: bool,
+ cx: &mut App,
+) -> Option<Task<Result<ExampleSpec>>> {
+ let ep_store = EditPredictionStore::try_global(cx)?;
+ let snapshot = buffer.read(cx).snapshot();
+ let file = snapshot.file()?;
+ let worktree_id = file.worktree_id(cx);
+ let repository = project.read(cx).active_repository(cx)?;
+ let repository_snapshot = repository.read(cx).snapshot();
+ let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
+ let cursor_path = worktree.read(cx).root_name().join(file.path());
+ if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
+ return None;
+ }
+
+ let repository_url = repository_snapshot
+ .remote_origin_url
+ .clone()
+ .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
+ let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
+
+ let mut events = ep_store.update(cx, |store, cx| {
+ store.edit_history_for_project_with_pause_split_last_event(&project, cx)
+ });
+
+ let git_store = project.read(cx).git_store().clone();
+
+ Some(cx.spawn(async move |mut cx| {
+ let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
+ let cursor_excerpt = cx
+ .background_executor()
+ .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
+ .await;
+ let uncommitted_diff = cx
+ .background_executor()
+ .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
+ .await;
+
+ let mut edit_history = String::new();
+ let mut expected_patch = String::new();
+ if last_event_is_expected_patch {
+ if let Some(stored_event) = events.pop() {
+ zeta_prompt::write_event(&mut expected_patch, &stored_event.event);
+ }
+ }
+
+ for stored_event in &events {
+ zeta_prompt::write_event(&mut edit_history, &stored_event.event);
+ if !edit_history.ends_with('\n') {
+ edit_history.push('\n');
+ }
+ }
+
+ let name = generate_timestamp_name();
+
+ Ok(ExampleSpec {
+ name,
+ repository_url,
+ revision,
+ uncommitted_diff,
+ cursor_path: cursor_path.as_std_path().into(),
+ cursor_position: cursor_excerpt,
+ edit_history,
+ expected_patch,
+ })
+ }))
+}
+
+fn compute_cursor_excerpt(
+ snapshot: &language::BufferSnapshot,
+ cursor_anchor: language::Anchor,
+) -> String {
+ let cursor_point = cursor_anchor.to_point(snapshot);
+ let (_editable_range, context_range) =
+ editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
+
+ let context_start_offset = context_range.start.to_offset(snapshot);
+ let cursor_offset = cursor_anchor.to_offset(snapshot);
+ let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
+ let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
+ if cursor_offset_in_excerpt <= excerpt.len() {
+ excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
+ }
+ excerpt
+}
+
+async fn collect_snapshots(
+ project: &Entity<Project>,
+ git_store: &Entity<project::git_store::GitStore>,
+ events: &[StoredEvent],
+ cx: &mut gpui::AsyncApp,
+) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
+ let mut snapshots_by_path = HashMap::default();
+ for stored_event in events {
+ let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
+ if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
+ let project_path = project.find_project_path(path, cx)?;
+ let full_path = project
+ .worktree_for_id(project_path.worktree_id, cx)?
+ .read(cx)
+ .root_name()
+ .join(&project_path.path)
+ .as_std_path()
+ .into();
+ Some((project_path, full_path))
+ })? {
+ if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(project_path.clone(), cx)
+ })?
+ .await?;
+ let diff = git_store
+ .update(cx, |git_store, cx| {
+ git_store.open_uncommitted_diff(buffer.clone(), cx)
+ })?
+ .await?;
+ let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
+ entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
+ }
+ }
+ }
+ Ok(snapshots_by_path)
+}
+
+fn compute_uncommitted_diff(
+ snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
+) -> String {
+ let mut uncommitted_diff = String::new();
+ for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
+ if let Some(head_text) = &diff_snapshot.base_text_string() {
+ let file_diff = language::unified_diff(head_text, &before_text.text());
+ if !file_diff.is_empty() {
+ let path_str = full_path.to_string_lossy();
+ writeln!(uncommitted_diff, "--- a/{path_str}").ok();
+ writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
+ uncommitted_diff.push_str(&file_diff);
+ if !uncommitted_diff.ends_with('\n') {
+ uncommitted_diff.push('\n');
+ }
+ }
+ }
+ }
+ uncommitted_diff
+}
+
+fn generate_timestamp_name() -> String {
+ let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
+ match format {
+ Ok(format) => {
+ let now = time::OffsetDateTime::now_local()
+ .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
+ now.format(&format)
+ .unwrap_or_else(|_| "unknown-time".to_string())
+ }
+ Err(_) => "unknown-time".to_string(),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use client::{Client, UserStore};
+ use clock::FakeSystemClock;
+ use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
+ use indoc::indoc;
+ use language::{Anchor, Point};
+ use project::{FakeFs, Project};
+ use serde_json::json;
+ use settings::SettingsStore;
+ use std::path::Path;
+
+ #[gpui::test]
+ async fn test_capture_example(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ let committed_contents = indoc! {"
+ fn main() {
+ one();
+ two();
+ three();
+ four();
+ five();
+ six();
+ seven();
+ eight();
+ nine();
+ }
+ "};
+
+ let disk_contents = indoc! {"
+ fn main() {
+ // comment 1
+ one();
+ two();
+ three();
+ four();
+ five();
+ six();
+ seven();
+ eight();
+ // comment 2
+ nine();
+ }
+ "};
+
+ fs.insert_tree(
+ "/project",
+ json!({
+ ".git": {},
+ "src": {
+ "main.rs": disk_contents,
+ }
+ }),
+ )
+ .await;
+
+ fs.set_head_for_repo(
+ Path::new("/project/.git"),
+ &[("src/main.rs", committed_contents.to_string())],
+ "abc123def456",
+ );
+ fs.set_remote_for_repo(
+ Path::new("/project/.git"),
+ "origin",
+ "https://github.com/test/repo.git",
+ );
+
+ let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer("/project/src/main.rs", cx)
+ })
+ .await
+ .unwrap();
+
+ let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx)
+ });
+ cx.run_until_parked();
+
+ buffer.update(cx, |buffer, cx| {
+ let point = Point::new(6, 0);
+ buffer.edit([(point..point, " // comment 3\n")], None, cx);
+ let point = Point::new(4, 0);
+ buffer.edit([(point..point, " // comment 4\n")], None, cx);
+
+ pretty_assertions::assert_eq!(
+ buffer.text(),
+ indoc! {"
+ fn main() {
+ // comment 1
+ one();
+ two();
+ // comment 4
+ three();
+ four();
+ // comment 3
+ five();
+ six();
+ seven();
+ eight();
+ // comment 2
+ nine();
+ }
+ "}
+ );
+ });
+ cx.run_until_parked();
+
+ let mut example = cx
+ .update(|cx| {
+ capture_example(project.clone(), buffer.clone(), Anchor::MIN, false, cx).unwrap()
+ })
+ .await
+ .unwrap();
+ example.name = "test".to_string();
+
+ pretty_assertions::assert_eq!(
+ example,
+ ExampleSpec {
+ name: "test".to_string(),
+ repository_url: "https://github.com/test/repo.git".to_string(),
+ revision: "abc123def456".to_string(),
+ uncommitted_diff: indoc! {"
+ --- a/project/src/main.rs
+ +++ b/project/src/main.rs
+ @@ -1,4 +1,5 @@
+ fn main() {
+ + // comment 1
+ one();
+ two();
+ three();
+ @@ -7,5 +8,6 @@
+ six();
+ seven();
+ eight();
+ + // comment 2
+ nine();
+ }
+ "}
+ .to_string(),
+ cursor_path: Path::new("project/src/main.rs").into(),
+ cursor_position: indoc! {"
+ <|user_cursor|>fn main() {
+ // comment 1
+ one();
+ two();
+ // comment 4
+ three();
+ four();
+ // comment 3
+ five();
+ six();
+ seven();
+ eight();
+ // comment 2
+ nine();
+ }
+ "}
+ .to_string(),
+ edit_history: indoc! {"
+ --- a/project/src/main.rs
+ +++ b/project/src/main.rs
+ @@ -2,8 +2,10 @@
+ // comment 1
+ one();
+ two();
+ + // comment 4
+ three();
+ four();
+ + // comment 3
+ five();
+ six();
+ seven();
+ "}
+ .to_string(),
+ expected_patch: "".to_string(),
+ }
+ );
+ }
+
+ fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ zlog::init_test();
+ let http_client = FakeHttpClient::with_404_response();
+ let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
+ language_model::init(client.clone(), cx);
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ EditPredictionStore::global(&client, &user_store, cx);
+ })
+ }
+}
@@ -35,6 +35,7 @@ use semver::Version;
use serde::de::DeserializeOwned;
use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
use std::collections::{VecDeque, hash_map};
+use text::Edit;
use workspace::Workspace;
use std::ops::Range;
@@ -57,9 +58,9 @@ pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
-#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
pub mod udiff;
+mod capture_example;
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -74,6 +75,7 @@ pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
+pub use capture_example::capture_example;
pub use language_model::ApiKeyState;
pub use telemetry_events::EditPredictionRating;
pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
@@ -231,8 +233,15 @@ pub struct EditPredictionFinishedDebugEvent {
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
+/// An event with associated metadata for reconstructing buffer state.
+#[derive(Clone)]
+pub struct StoredEvent {
+ pub event: Arc<zeta_prompt::Event>,
+ pub old_snapshot: TextBufferSnapshot,
+}
+
struct ProjectState {
- events: VecDeque<Arc<zeta_prompt::Event>>,
+ events: VecDeque<StoredEvent>,
last_event: Option<LastEvent>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
@@ -248,7 +257,7 @@ struct ProjectState {
}
impl ProjectState {
- pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
+ pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
self.events
.iter()
.cloned()
@@ -260,7 +269,7 @@ impl ProjectState {
.collect()
}
- pub fn events_split_by_pause(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
+ pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
self.events
.iter()
.cloned()
@@ -415,7 +424,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
cx: &App,
- ) -> Option<Arc<zeta_prompt::Event>> {
+ ) -> Option<StoredEvent> {
let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
@@ -430,19 +439,22 @@ impl LastEvent {
})
});
- let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
+ let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
if path == old_path && diff.is_empty() {
None
} else {
- Some(Arc::new(zeta_prompt::Event::BufferChange {
- old_path,
- path,
- diff,
- in_open_source_repo,
- // TODO: Actually detect if this edit was predicted or not
- predicted: false,
- }))
+ Some(StoredEvent {
+ event: Arc::new(zeta_prompt::Event::BufferChange {
+ old_path,
+ path,
+ diff,
+ in_open_source_repo,
+ // TODO: Actually detect if this edit was predicted or not
+ predicted: false,
+ }),
+ old_snapshot: self.old_snapshot.clone(),
+ })
}
}
@@ -475,6 +487,52 @@ impl LastEvent {
}
}
+pub(crate) fn compute_diff_between_snapshots(
+ old_snapshot: &TextBufferSnapshot,
+ new_snapshot: &TextBufferSnapshot,
+) -> Option<String> {
+ let edits: Vec<Edit<usize>> = new_snapshot
+ .edits_since::<usize>(&old_snapshot.version)
+ .collect();
+
+ let (first_edit, last_edit) = edits.first().zip(edits.last())?;
+
+ let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
+ let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
+ let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
+ let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
+
+ const CONTEXT_LINES: u32 = 3;
+
+ let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
+ let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
+ let old_context_end_row =
+ (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
+ let new_context_end_row =
+ (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
+
+ let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
+ let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
+ let old_end_line_offset = old_snapshot
+ .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
+ let new_end_line_offset = new_snapshot
+ .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
+ let old_edit_range = old_start_line_offset..old_end_line_offset;
+ let new_edit_range = new_start_line_offset..new_end_line_offset;
+
+ let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
+ let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
+
+ let diff = language::unified_diff_with_offsets(
+ &old_region_text,
+ &new_region_text,
+ old_context_start_row,
+ new_context_start_row,
+ );
+
+ Some(diff)
+}
+
fn buffer_path_with_id_fallback(
file: Option<&Arc<dyn File>>,
snapshot: &TextBufferSnapshot,
@@ -643,7 +701,7 @@ impl EditPredictionStore {
&self,
project: &Entity<Project>,
cx: &App,
- ) -> Vec<Arc<zeta_prompt::Event>> {
+ ) -> Vec<StoredEvent> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events(cx))
@@ -654,7 +712,7 @@ impl EditPredictionStore {
&self,
project: &Entity<Project>,
cx: &App,
- ) -> Vec<Arc<zeta_prompt::Event>> {
+ ) -> Vec<StoredEvent> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events_split_by_pause(cx))
@@ -1536,8 +1594,10 @@ impl EditPredictionStore {
self.get_or_init_project(&project, cx);
let project_state = self.projects.get(&project.entity_id()).unwrap();
- let events = project_state.events(cx);
- let has_events = !events.is_empty();
+ let stored_events = project_state.events(cx);
+ let has_events = !stored_events.is_empty();
+ let events: Vec<Arc<zeta_prompt::Event>> =
+ stored_events.into_iter().map(|e| e.event).collect();
let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();
@@ -1,5 +1,5 @@
use super::*;
-use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
+use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -360,7 +360,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
ep_store.edit_history_for_project(&project, cx)
});
assert_eq!(events.len(), 1);
- let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
+ let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
assert_eq!(
diff.as_str(),
indoc! {"
@@ -377,7 +377,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
});
assert_eq!(events.len(), 2);
- let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
+ let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
assert_eq!(
diff.as_str(),
indoc! {"
@@ -389,7 +389,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
"}
);
- let zeta_prompt::Event::BufferChange { diff, .. } = events[1].as_ref();
+ let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
assert_eq!(
diff.as_str(),
indoc! {"
@@ -2082,6 +2082,74 @@ async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut Te
);
}
+#[gpui::test]
+fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
+ let buffer = cx.new(|cx| {
+ Buffer::local(
+ indoc! {"
+ zero
+ one
+ two
+ three
+ four
+ five
+ six
+ seven
+ eight
+ nine
+ ten
+ eleven
+ twelve
+ thirteen
+ fourteen
+ fifteen
+ sixteen
+ seventeen
+ eighteen
+ nineteen
+ twenty
+ twenty-one
+ twenty-two
+ twenty-three
+ twenty-four
+ "},
+ cx,
+ )
+ });
+
+ let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
+
+ buffer.update(cx, |buffer, cx| {
+ let point = Point::new(12, 0);
+ buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
+ let point = Point::new(8, 0);
+ buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
+ });
+
+ let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
+
+ let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
+
+ assert_eq!(
+ diff,
+ indoc! {"
+ @@ -6,10 +6,12 @@
+ five
+ six
+ seven
+ +FIRST INSERTION
+ eight
+ nine
+ ten
+ eleven
+ +SECOND INSERTION
+ twelve
+ thirteen
+ fourteen
+ "}
+ );
+}
+
#[ctor::ctor]
fn init_logger() {
zlog::init_test();
@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::{fmt::Write as _, mem, path::Path, sync::Arc};
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExampleSpec {
#[serde(default)]
pub name: String,
@@ -45,6 +45,11 @@ pub async fn run_format_prompt(
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
let project = state.project.clone();
let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
+ let events = ep_store
+ .edit_history_for_project(&project, cx)
+ .into_iter()
+ .map(|e| e.event)
+ .collect();
anyhow::Ok(zeta2_prompt_input(
&snapshot,
example
@@ -53,7 +58,7 @@ pub async fn run_format_prompt(
.context("context must be set")?
.files
.clone(),
- ep_store.edit_history_for_project(&project, cx),
+ events,
example.spec.cursor_path.clone(),
example
.buffer
@@ -15,8 +15,7 @@ doctest = false
[dependencies]
anyhow.workspace = true
buffer_diff.workspace = true
-git.workspace = true
-log.workspace = true
+collections.workspace = true
time.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
@@ -50,11 +49,18 @@ zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
+clock.workspace = true
copilot = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
futures.workspace = true
indoc.workspace = true
+language_model.workspace = true
lsp = { workspace = true, features = ["test-support"] }
+pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
+release_channel.workspace = true
+semver.workspace = true
serde_json.workspace = true
theme = { workspace = true, features = ["test-support"] }
+workspace = { workspace = true, features = ["test-support"] }
+zlog.workspace = true
@@ -915,11 +915,8 @@ impl EditPredictionButton {
.when(
cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>(),
|this| {
- this.action(
- "Capture Edit Prediction Example",
- CaptureExample.boxed_clone(),
- )
- .action("Rate Predictions", RatePredictions.boxed_clone())
+ this.action("Capture Prediction Example", CaptureExample.boxed_clone())
+ .action("Rate Predictions", RatePredictions.boxed_clone())
},
);
}
@@ -2,25 +2,17 @@ mod edit_prediction_button;
mod edit_prediction_context_view;
mod rate_prediction_modal;
-use std::any::{Any as _, TypeId};
-use std::path::Path;
-use std::sync::Arc;
-
use command_palette_hooks::CommandPaletteFilter;
-use edit_prediction::{
- EditPredictionStore, ResetOnboarding, Zeta2FeatureFlag, example_spec::ExampleSpec,
-};
+use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag, capture_example};
use edit_prediction_context_view::EditPredictionContextView;
use editor::Editor;
use feature_flags::FeatureFlagAppExt as _;
-use git::repository::DiffType;
-use gpui::{Window, actions};
-use language::ToPoint as _;
-use log;
+use gpui::actions;
+use language::language_settings::AllLanguageSettings;
use project::DisableAiSettings;
use rate_prediction_modal::RatePredictionsModal;
use settings::{Settings as _, SettingsStore};
-use text::ToOffset as _;
+use std::any::{Any as _, TypeId};
use ui::{App, prelude::*};
use workspace::{SplitDirection, Workspace};
@@ -56,7 +48,9 @@ pub fn init(cx: &mut App) {
}
});
- workspace.register_action(capture_edit_prediction_example);
+ workspace.register_action(|workspace, _: &CaptureExample, window, cx| {
+ capture_example_as_markdown(workspace, window, cx);
+ });
workspace.register_action_renderer(|div, _, _, cx| {
let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
div.when(has_flag, |div| {
@@ -138,182 +132,48 @@ fn feature_gate_predict_edits_actions(cx: &mut App) {
.detach();
}
-fn capture_edit_prediction_example(
+fn capture_example_as_markdown(
workspace: &mut Workspace,
- _: &CaptureExample,
window: &mut Window,
cx: &mut Context<Workspace>,
-) {
- let Some(ep_store) = EditPredictionStore::try_global(cx) else {
- return;
- };
+) -> Option<()> {
+ let markdown_language = workspace
+ .app_state()
+ .languages
+ .language_for_name("Markdown");
+ let fs = workspace.app_state().fs.clone();
let project = workspace.project().clone();
-
- let (worktree_root, repository) = {
- let project_ref = project.read(cx);
- let worktree_root = project_ref
- .visible_worktrees(cx)
- .next()
- .map(|worktree| worktree.read(cx).abs_path());
- let repository = project_ref.active_repository(cx);
- (worktree_root, repository)
- };
-
- let (Some(worktree_root), Some(repository)) = (worktree_root, repository) else {
- log::error!("CaptureExampleSpec: missing worktree or active repository");
- return;
- };
-
- let repository_snapshot = repository.read(cx).snapshot();
- if worktree_root.as_ref() != repository_snapshot.work_directory_abs_path.as_ref() {
- log::error!(
- "repository is not at worktree root (repo={:?}, worktree={:?})",
- repository_snapshot.work_directory_abs_path,
- worktree_root
- );
- return;
- }
-
- let Some(repository_url) = repository_snapshot
- .remote_origin_url
- .clone()
- .or_else(|| repository_snapshot.remote_upstream_url.clone())
- else {
- log::error!("active repository has no origin/upstream remote url");
- return;
- };
-
- let Some(revision) = repository_snapshot
- .head_commit
- .as_ref()
- .map(|commit| commit.sha.to_string())
- else {
- log::error!("active repository has no head commit");
- return;
- };
-
- let mut events = ep_store.update(cx, |store, cx| {
- store.edit_history_for_project_with_pause_split_last_event(&project, cx)
- });
-
- let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
- log::error!("no active editor");
- return;
- };
-
- let Some(project_path) = editor.read(cx).project_path(cx) else {
- log::error!("active editor has no project path");
- return;
- };
-
- let Some((buffer, cursor_anchor)) = editor
- .read(cx)
+ let editor = workspace.active_item_as::<Editor>(cx)?;
+ let editor = editor.read(cx);
+ let (buffer, cursor_anchor) = editor
.buffer()
.read(cx)
- .text_anchor_for_position(editor.read(cx).selections.newest_anchor().head(), cx)
- else {
- log::error!("failed to resolve cursor buffer/anchor");
- return;
- };
-
- let snapshot = buffer.read(cx).snapshot();
- let cursor_point = cursor_anchor.to_point(&snapshot);
- let (_editable_range, context_range) =
- edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
- cursor_point,
- &snapshot,
- 100,
- 50,
- );
-
- let cursor_path: Arc<Path> = repository
- .read(cx)
- .project_path_to_repo_path(&project_path, cx)
- .map(|repo_path| Path::new(repo_path.as_unix_str()).into())
- .unwrap_or_else(|| Path::new(project_path.path.as_unix_str()).into());
+ .text_anchor_for_position(editor.selections.newest_anchor().head(), cx)?;
+ let example = capture_example(project.clone(), buffer, cursor_anchor, true, cx)?;
- let cursor_position = {
- let context_start_offset = context_range.start.to_offset(&snapshot);
- let cursor_offset = cursor_anchor.to_offset(&snapshot);
- let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
- let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
- if cursor_offset_in_excerpt <= excerpt.len() {
- excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
- }
- excerpt
- };
-
- let markdown_language = workspace
- .app_state()
- .languages
- .language_for_name("Markdown");
+ let examples_dir = AllLanguageSettings::get_global(cx)
+ .edit_predictions
+ .examples_dir
+ .clone();
cx.spawn_in(window, async move |workspace_entity, cx| {
let markdown_language = markdown_language.await?;
+ let example_spec = example.await?;
+ let buffer = if let Some(dir) = examples_dir {
+ fs.create_dir(&dir).await.ok();
+ let mut path = dir.join(&example_spec.name.replace(' ', "--").replace(':', "-"));
+ path.set_extension("md");
+ project.update(cx, |project, cx| project.open_local_buffer(&path, cx))
+ } else {
+ project.update(cx, |project, cx| project.create_buffer(false, cx))
+ }?
+ .await?;
- let uncommitted_diff_rx = repository.update(cx, |repository, cx| {
- repository.diff(DiffType::HeadToWorktree, cx)
- })?;
-
- let uncommitted_diff = match uncommitted_diff_rx.await {
- Ok(Ok(diff)) => diff,
- Ok(Err(error)) => {
- log::error!("failed to compute uncommitted diff: {error:#}");
- return Ok(());
- }
- Err(error) => {
- log::error!("uncommitted diff channel dropped: {error:#}");
- return Ok(());
- }
- };
-
- let mut edit_history = String::new();
- let mut expected_patch = String::new();
- if let Some(last_event) = events.pop() {
- for event in &events {
- zeta_prompt::write_event(&mut edit_history, event);
- if !edit_history.ends_with('\n') {
- edit_history.push('\n');
- }
- edit_history.push('\n');
- }
-
- zeta_prompt::write_event(&mut expected_patch, &last_event);
- }
-
- let format =
- time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
- let name = match format {
- Ok(format) => {
- let now = time::OffsetDateTime::now_local()
- .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
- now.format(&format)
- .unwrap_or_else(|_| "unknown-time".to_string())
- }
- Err(_) => "unknown-time".to_string(),
- };
-
- let markdown = ExampleSpec {
- name,
- repository_url,
- revision,
- uncommitted_diff,
- cursor_path,
- cursor_position,
- edit_history,
- expected_patch,
- }
- .to_markdown();
-
- let buffer = project
- .update(cx, |project, cx| project.create_buffer(false, cx))?
- .await?;
buffer.update(cx, |buffer, cx| {
- buffer.set_text(markdown, cx);
+ buffer.set_text(example_spec.to_markdown(), cx);
buffer.set_language(Some(markdown_language), cx);
})?;
-
workspace_entity.update_in(cx, |workspace, window, cx| {
workspace.add_item_to_active_pane(
Box::new(
@@ -327,4 +187,5 @@ fn capture_edit_prediction_example(
})
})
.detach_and_log_err(cx);
+ None
}
@@ -156,8 +156,16 @@ impl GitRepository for FakeGitRepository {
})
}
- fn remote_url(&self, _name: &str) -> BoxFuture<'_, Option<String>> {
- async move { None }.boxed()
+ fn remote_url(&self, name: &str) -> BoxFuture<'_, Option<String>> {
+ let name = name.to_string();
+ let fut = self.with_state_async(false, move |state| {
+ state
+ .remotes
+ .get(&name)
+ .context("remote not found")
+ .cloned()
+ });
+ async move { fut.await.ok() }.boxed()
}
fn diff_tree(&self, _request: DiffTreeType) -> BoxFuture<'_, Result<TreeDiff>> {
@@ -1857,6 +1857,18 @@ impl FakeFs {
.unwrap();
}
+ pub fn set_remote_for_repo(
+ &self,
+ dot_git: &Path,
+ name: impl Into<String>,
+ url: impl Into<String>,
+ ) {
+ self.with_git_state(dot_git, true, |state| {
+ state.remotes.insert(name.into(), url.into());
+ })
+ .unwrap();
+ }
+
pub fn insert_branches(&self, dot_git: &Path, branches: &[&str]) {
self.with_git_state(dot_git, true, |state| {
if let Some(first) = branches.first()
@@ -67,7 +67,7 @@ use task::RunnableTag;
pub use task_context::{ContextLocation, ContextProvider, RunnableRange};
pub use text_diff::{
DiffOptions, apply_diff_patch, line_diff, text_diff, text_diff_with_options, unified_diff,
- word_diff_ranges,
+ unified_diff_with_offsets, word_diff_ranges,
};
use theme::SyntaxTheme;
pub use toolchain::{
@@ -392,6 +392,7 @@ pub struct EditPredictionSettings {
/// Whether edit predictions are enabled in the assistant panel.
/// This setting has no effect if globally disabled.
pub enabled_in_text_threads: bool,
+ pub examples_dir: Option<Arc<Path>>,
}
impl EditPredictionSettings {
@@ -699,6 +700,7 @@ impl settings::Settings for AllLanguageSettings {
copilot: copilot_settings,
codestral: codestral_settings,
enabled_in_text_threads,
+ examples_dir: edit_predictions.examples_dir,
},
defaults: default_language_settings,
languages,
@@ -1,25 +1,139 @@
use crate::{CharClassifier, CharKind, CharScopeContext, LanguageScope};
use anyhow::{Context, anyhow};
use imara_diff::{
- Algorithm, UnifiedDiffBuilder, diff,
- intern::{InternedInput, Token},
+ Algorithm, Sink, diff,
+ intern::{InternedInput, Interner, Token},
sources::lines_with_terminator,
};
-use std::{iter, ops::Range, sync::Arc};
+use std::{fmt::Write, iter, ops::Range, sync::Arc};
const MAX_WORD_DIFF_LEN: usize = 512;
const MAX_WORD_DIFF_LINE_COUNT: usize = 8;
/// Computes a diff between two strings, returning a unified diff string.
pub fn unified_diff(old_text: &str, new_text: &str) -> String {
+ unified_diff_with_offsets(old_text, new_text, 0, 0)
+}
+
+/// Computes a diff between two strings, returning a unified diff string with
+/// hunk headers adjusted to reflect the given starting line numbers (1-indexed).
+pub fn unified_diff_with_offsets(
+ old_text: &str,
+ new_text: &str,
+ old_start_line: u32,
+ new_start_line: u32,
+) -> String {
let input = InternedInput::new(old_text, new_text);
diff(
Algorithm::Histogram,
&input,
- UnifiedDiffBuilder::new(&input),
+ OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line),
)
}
+/// A unified diff builder that applies line number offsets to hunk headers.
+struct OffsetUnifiedDiffBuilder<'a> {
+ before: &'a [Token],
+ after: &'a [Token],
+ interner: &'a Interner<&'a str>,
+
+ pos: u32,
+ before_hunk_start: u32,
+ after_hunk_start: u32,
+ before_hunk_len: u32,
+ after_hunk_len: u32,
+
+ old_line_offset: u32,
+ new_line_offset: u32,
+
+ buffer: String,
+ dst: String,
+}
+
+impl<'a> OffsetUnifiedDiffBuilder<'a> {
+ fn new(input: &'a InternedInput<&'a str>, old_line_offset: u32, new_line_offset: u32) -> Self {
+ Self {
+ before_hunk_start: 0,
+ after_hunk_start: 0,
+ before_hunk_len: 0,
+ after_hunk_len: 0,
+ old_line_offset,
+ new_line_offset,
+ buffer: String::with_capacity(8),
+ dst: String::new(),
+ interner: &input.interner,
+ before: &input.before,
+ after: &input.after,
+ pos: 0,
+ }
+ }
+
+ fn print_tokens(&mut self, tokens: &[Token], prefix: char) {
+ for &token in tokens {
+ writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap();
+ }
+ }
+
+ fn flush(&mut self) {
+ if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
+ return;
+ }
+
+ let end = (self.pos + 3).min(self.before.len() as u32);
+ self.update_pos(end, end);
+
+ writeln!(
+ &mut self.dst,
+ "@@ -{},{} +{},{} @@",
+ self.before_hunk_start + 1 + self.old_line_offset,
+ self.before_hunk_len,
+ self.after_hunk_start + 1 + self.new_line_offset,
+ self.after_hunk_len,
+ )
+ .unwrap();
+ write!(&mut self.dst, "{}", &self.buffer).unwrap();
+ self.buffer.clear();
+ self.before_hunk_len = 0;
+ self.after_hunk_len = 0;
+ }
+
+ fn update_pos(&mut self, print_to: u32, move_to: u32) {
+ self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' ');
+ let len = print_to - self.pos;
+ self.pos = move_to;
+ self.before_hunk_len += len;
+ self.after_hunk_len += len;
+ }
+}
+
+impl Sink for OffsetUnifiedDiffBuilder<'_> {
+ type Out = String;
+
+ fn process_change(&mut self, before: Range<u32>, after: Range<u32>) {
+ if before.start - self.pos > 6 {
+ self.flush();
+ }
+ if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
+ self.pos = before.start.saturating_sub(3);
+ self.before_hunk_start = self.pos;
+ self.after_hunk_start = after.start.saturating_sub(3);
+ }
+ self.update_pos(before.start, before.end);
+ self.before_hunk_len += before.end - before.start;
+ self.after_hunk_len += after.end - after.start;
+ self.print_tokens(
+ &self.before[before.start as usize..before.end as usize],
+ '-',
+ );
+ self.print_tokens(&self.after[after.start as usize..after.end as usize], '+');
+ }
+
+ fn finish(mut self) -> Self::Out {
+ self.flush();
+ self.dst
+ }
+}
+
/// Computes a diff between two strings, returning a vector of old and new row
/// ranges.
pub fn line_diff(old_text: &str, new_text: &str) -> Vec<(Range<u32>, Range<u32>)> {
@@ -327,4 +441,30 @@ mod tests {
let patch = unified_diff(old_text, new_text);
assert_eq!(apply_diff_patch(old_text, &patch).unwrap(), new_text);
}
+
+ #[test]
+ fn test_unified_diff_with_offsets() {
+ let old_text = "foo\nbar\nbaz\n";
+ let new_text = "foo\nBAR\nbaz\n";
+
+ let expected_diff_body = " foo\n-bar\n+BAR\n baz\n";
+
+ let diff_no_offset = unified_diff(old_text, new_text);
+ assert_eq!(
+ diff_no_offset,
+ format!("@@ -1,3 +1,3 @@\n{}", expected_diff_body)
+ );
+
+ let diff_with_offset = unified_diff_with_offsets(old_text, new_text, 9, 11);
+ assert_eq!(
+ diff_with_offset,
+ format!("@@ -10,3 +12,3 @@\n{}", expected_diff_body)
+ );
+
+ let diff_with_offset = unified_diff_with_offsets(old_text, new_text, 99, 104);
+ assert_eq!(
+ diff_with_offset,
+ format!("@@ -100,3 +105,3 @@\n{}", expected_diff_body)
+ );
+ }
}
@@ -5756,6 +5756,7 @@ impl Repository {
cx.spawn(|_: &mut AsyncApp| async move { rx.await? })
}
+
fn load_blob_content(&mut self, oid: Oid, cx: &App) -> Task<Result<String>> {
let repository_id = self.snapshot.id;
let rx = self.send_job(None, move |state, _| async move {
@@ -56,6 +56,7 @@ merge_from_overwrites!(
std::sync::Arc<str>,
gpui::SharedString,
std::path::PathBuf,
+ std::sync::Arc<std::path::Path>,
gpui::Modifiers,
gpui::FontFeatures,
gpui::FontWeight
@@ -1,4 +1,4 @@
-use std::num::NonZeroU32;
+use std::{num::NonZeroU32, path::Path};
use collections::{HashMap, HashSet};
use gpui::{Modifiers, SharedString};
@@ -167,6 +167,8 @@ pub struct EditPredictionSettingsContent {
/// Whether edit predictions are enabled in the assistant prompt editor.
/// This has no effect if globally disabled.
pub enabled_in_text_threads: Option<bool>,
+ /// The directory where manually captured edit prediction examples are stored.
+ pub examples_dir: Option<Arc<Path>>,
}
#[with_fallible_options]