1use crate::{StoredEvent, example_spec::ExampleSpec};
2use anyhow::Result;
3use buffer_diff::BufferDiffSnapshot;
4use collections::HashMap;
5use gpui::{App, Entity, Task};
6use language::Buffer;
7use project::{Project, WorktreeId};
8use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
9use text::{BufferSnapshot as TextBufferSnapshot, Point};
10
11pub fn capture_example(
12 project: Entity<Project>,
13 buffer: Entity<Buffer>,
14 cursor_anchor: language::Anchor,
15 mut events: Vec<StoredEvent>,
16 populate_expected_patch: bool,
17 cx: &mut App,
18) -> Option<Task<Result<ExampleSpec>>> {
19 let snapshot = buffer.read(cx).snapshot();
20 let file = snapshot.file()?;
21 let worktree_id = file.worktree_id(cx);
22 let repository = project.read(cx).active_repository(cx)?;
23 let repository_snapshot = repository.read(cx).snapshot();
24 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
25 let root_name = worktree.read(cx).root_name_str().to_owned();
26 let cursor_path: Arc<Path> = file.path().as_std_path().into();
27 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
28 return None;
29 }
30
31 let repository_url = repository_snapshot
32 .remote_origin_url
33 .clone()
34 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
35 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
36
37 let git_store = project.read(cx).git_store().clone();
38
39 Some(cx.spawn(async move |mut cx| {
40 let snapshots_by_path =
41 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
42
43 events.retain(|stored_event| {
44 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
45 let relative_path = strip_root_name(path, &root_name);
46 snapshots_by_path.contains_key(relative_path)
47 });
48
49 let line_comment_prefix = snapshot
50 .language()
51 .and_then(|lang| lang.config().line_comments.first())
52 .map(|s| s.to_string())
53 .unwrap_or_default();
54
55 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
56 .background_executor()
57 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
58 .await;
59 let uncommitted_diff = cx
60 .background_executor()
61 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
62 .await;
63
64 let mut edit_history = String::new();
65 for stored_event in &events {
66 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
67 if !edit_history.ends_with('\n') {
68 edit_history.push('\n');
69 }
70 }
71
72 // Initialize an empty patch with context lines, to make it easy
73 // to write the expected patch by hand.
74 let mut expected_patches = Vec::new();
75 let mut rejected_patch = None;
76 if populate_expected_patch {
77 let mut empty_patch = String::new();
78 let start_row = cursor_excerpt_range.start.row + 1;
79 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
80 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
81 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
82 writeln!(
83 &mut empty_patch,
84 "@@ -{},{} +{},{} @@",
85 start_row, row_count, start_row, row_count,
86 )
87 .ok();
88 for line in cursor_excerpt.lines() {
89 writeln!(&mut empty_patch, " {}", line).ok();
90 }
91
92 expected_patches.push(empty_patch.clone());
93 rejected_patch = Some(empty_patch);
94 }
95
96 let mut spec = ExampleSpec {
97 name: generate_timestamp_name(),
98 repository_url,
99 revision,
100 tags: Vec::new(),
101 reasoning: None,
102 uncommitted_diff,
103 cursor_path,
104 cursor_position: String::new(),
105 edit_history,
106 expected_patches,
107 rejected_patch,
108 telemetry: None,
109 human_feedback: Vec::new(),
110 rating: None,
111 };
112 spec.set_cursor_excerpt(
113 &cursor_excerpt,
114 cursor_offset_in_excerpt,
115 &line_comment_prefix,
116 );
117 Ok(spec)
118 }))
119}
120
/// Returns `path` with a leading `root_name` component prefix removed.
///
/// If `path` does not start with `root_name` (compared component-wise, as
/// [`Path::strip_prefix`] does), `path` is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
124
125fn write_event_with_relative_paths(
126 output: &mut String,
127 event: &zeta_prompt::Event,
128 root_name: &str,
129) {
130 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
131 for component in strip_root_name(path, root_name).components() {
132 output.push('/');
133 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
134 }
135 }
136
137 let zeta_prompt::Event::BufferChange {
138 path,
139 old_path,
140 diff,
141 ..
142 } = event;
143
144 output.push_str("--- a");
145 write_relative_path(output, old_path.as_ref(), root_name);
146 output.push_str("\n+++ b");
147 write_relative_path(output, path.as_ref(), root_name);
148 output.push('\n');
149 output.push_str(diff);
150}
151
152fn compute_cursor_excerpt(
153 snapshot: &language::BufferSnapshot,
154 cursor_anchor: language::Anchor,
155) -> (String, usize, Range<Point>) {
156 use text::ToOffset as _;
157 use text::ToPoint as _;
158
159 let cursor_offset = cursor_anchor.to_offset(snapshot);
160 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
161 crate::cursor_excerpt::compute_cursor_excerpt(snapshot, cursor_offset);
162 let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
163 snapshot,
164 cursor_offset,
165 &excerpt_offset_range,
166 );
167 let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
168 let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
169 &excerpt_text,
170 cursor_offset_in_excerpt,
171 &syntax_ranges,
172 100,
173 50,
174 );
175 let context_text = excerpt_text[context_range.clone()].to_string();
176 let cursor_in_context = cursor_offset_in_excerpt.saturating_sub(context_range.start);
177 let context_buffer_start =
178 (excerpt_offset_range.start + context_range.start).to_point(snapshot);
179 let context_buffer_end = (excerpt_offset_range.start + context_range.end).to_point(snapshot);
180 (
181 context_text,
182 cursor_in_context,
183 context_buffer_start..context_buffer_end,
184 )
185}
186
187async fn collect_snapshots(
188 project: &Entity<Project>,
189 git_store: &Entity<project::git_store::GitStore>,
190 worktree_id: WorktreeId,
191 events: &[StoredEvent],
192 cx: &mut gpui::AsyncApp,
193) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
194 let mut snapshots_by_path = HashMap::default();
195 for stored_event in events {
196 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
197 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
198 let project_path = project
199 .find_project_path(path, cx)
200 .filter(|path| path.worktree_id == worktree_id)?;
201 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
202 Some((project_path, relative_path))
203 }) {
204 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
205 let buffer = project
206 .update(cx, |project, cx| {
207 project.open_buffer(project_path.clone(), cx)
208 })
209 .await?;
210 let diff = git_store
211 .update(cx, |git_store, cx| {
212 git_store.open_uncommitted_diff(buffer.clone(), cx)
213 })
214 .await?;
215 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
216 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
217 }
218 }
219 }
220 Ok(snapshots_by_path)
221}
222
223fn compute_uncommitted_diff(
224 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
225) -> String {
226 let mut uncommitted_diff = String::new();
227 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
228 if let Some(head_text) = &diff_snapshot.base_text_string() {
229 let file_diff = language::unified_diff(head_text, &before_text.text());
230 if !file_diff.is_empty() {
231 let path_str = relative_path.to_string_lossy();
232 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
233 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
234 uncommitted_diff.push_str(&file_diff);
235 if !uncommitted_diff.ends_with('\n') {
236 uncommitted_diff.push('\n');
237 }
238 }
239 }
240 }
241 uncommitted_diff
242}
243
244fn generate_timestamp_name() -> String {
245 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
246 match format {
247 Ok(format) => {
248 let now = time::OffsetDateTime::now_local()
249 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
250 now.format(&format)
251 .unwrap_or_else(|_| "unknown-time".to_string())
252 }
253 Err(_) => "unknown-time".to_string(),
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`: a file with committed, on-disk
    /// and in-buffer changes, plus an edited file outside the worktree that
    /// must not appear in the captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents recorded at HEAD.
        let head_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted comment lines.
        let working_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": working_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        let git_dir = Path::new("/project/.git");
        fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Two in-buffer edits. The later row is edited first so that the
        // second insertion point is not shifted by the first.
        buffer.update(cx, |buffer, cx| {
            let before_five = Point::new(6, 0);
            buffer.edit([(before_five..before_five, "    // comment 3\n")], None, cx);
            let before_three = Point::new(4, 0);
            buffer.edit([(before_three..before_three, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let file_start = Point::new(0, 0);
            buffer.edit([(file_start..file_start, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        {
            let newest_event = events.last().unwrap();
            let zeta_prompt::Event::BufferChange { path, .. } = newest_event.event.as_ref();
            assert!(
                path.as_ref() == "/external/external.rs",
                "external file edit should be in events"
            );
        }

        let capture_task = cx.update(|cx| {
            capture_example(
                project.clone(),
                buffer.clone(),
                Anchor::MIN,
                events,
                true,
                cx,
            )
            .unwrap()
        });
        let mut example = capture_task.await.unwrap();
        // The generated name is time-based; pin it so the comparison is stable.
        example.name = "test".to_string();

        // With `populate_expected_patch` set, the same context-only patch is
        // used for both the expected and the rejected patch.
        let context_only_patch = indoc! {"
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,16 +1,16 @@
             fn main() {
                 // comment 1
                 one();
                 two();
                 // comment 4
                 three();
                 four();
                 // comment 3
                 five();
                 six();
                 seven();
                 eight();
                 // comment 2
                 nine();
             }
        "}
        .to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. the state before the in-buffer edits.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the in-worktree edits; the external file's event is dropped.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![context_only_patch.clone()],
                rejected_patch: Some(context_only_patch),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store, logging,
    /// the language-model subsystem, and the global `EditPredictionStore`
    /// backed by a client whose HTTP layer answers every request with 404.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let test_settings = SettingsStore::test(cx);
            cx.set_global(test_settings);
            zlog::init_test();

            let clock = Arc::new(FakeSystemClock::new());
            let client = Client::new(clock, FakeHttpClient::with_404_response(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(user_store.clone(), client.clone(), cx);
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}