1use crate::{
2 StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position,
4 example_spec::{
5 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
6 ExampleSpec, MAX_CURSOR_FILE_SIZE,
7 },
8};
9use anyhow::Result;
10use buffer_diff::BufferDiffSnapshot;
11use collections::HashMap;
12use gpui::{App, Entity, Task};
13use language::{Buffer, ToPoint as _};
14use project::{Project, WorktreeId};
15use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
16use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
17
18pub fn capture_example(
19 project: Entity<Project>,
20 buffer: Entity<Buffer>,
21 cursor_anchor: language::Anchor,
22 mut events: Vec<StoredEvent>,
23 related_files: Vec<zeta_prompt::RelatedFile>,
24 populate_expected_patch: bool,
25 cx: &mut App,
26) -> Option<Task<Result<ExampleSpec>>> {
27 let snapshot = buffer.read(cx).snapshot();
28 let file = snapshot.file()?;
29 let worktree_id = file.worktree_id(cx);
30 let repository = project.read(cx).active_repository(cx)?;
31 let repository_snapshot = repository.read(cx).snapshot();
32 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
33 let root_name = worktree.read(cx).root_name_str().to_owned();
34 let cursor_path: Arc<Path> = file.path().as_std_path().into();
35 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
36 return None;
37 }
38
39 let repository_url = repository_snapshot
40 .remote_origin_url
41 .clone()
42 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
43 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
44
45 let git_store = project.read(cx).git_store().clone();
46
47 Some(cx.spawn(async move |mut cx| {
48 let snapshots_by_path =
49 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
50
51 events.retain(|stored_event| {
52 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
53 let relative_path = strip_root_name(path, &root_name);
54 snapshots_by_path.contains_key(relative_path)
55 });
56
57 let line_comment_prefix = snapshot
58 .language()
59 .and_then(|lang| lang.config().line_comments.first())
60 .map(|s| s.to_string())
61 .unwrap_or_default();
62
63 let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
64 let cursor_point = cursor_anchor.to_point(&snapshot);
65 let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
66 Some(snapshot.text())
67 } else {
68 None
69 };
70
71 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
72 .background_executor()
73 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
74 .await;
75 let uncommitted_diff = cx
76 .background_executor()
77 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
78 .await;
79
80 let mut edit_history = String::new();
81 for stored_event in &events {
82 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
83 if !edit_history.ends_with('\n') {
84 edit_history.push('\n');
85 }
86 }
87
88 // Initialize an empty patch with context lines, to make it easy
89 // to write the expected patch by hand.
90 let mut expected_patches = Vec::new();
91 let mut rejected_patch = None;
92 if populate_expected_patch {
93 let mut empty_patch = String::new();
94 let start_row = cursor_excerpt_range.start.row + 1;
95 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
96 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
97 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
98 writeln!(
99 &mut empty_patch,
100 "@@ -{},{} +{},{} @@",
101 start_row, row_count, start_row, row_count,
102 )
103 .ok();
104 for line in cursor_excerpt.lines() {
105 writeln!(&mut empty_patch, " {}", line).ok();
106 }
107
108 expected_patches.push(empty_patch.clone());
109 rejected_patch = Some(empty_patch);
110 }
111
112 let prompt_input = cursor_file_content.map(|content| {
113 let captured_events: Vec<CapturedEvent> = events
114 .iter()
115 .map(|stored_event| {
116 let zeta_prompt::Event::BufferChange {
117 path,
118 old_path,
119 diff,
120 predicted,
121 in_open_source_repo,
122 } = stored_event.event.as_ref();
123 CapturedEvent {
124 path: strip_root_name(path, &root_name).into(),
125 old_path: strip_root_name(old_path, &root_name).into(),
126 diff: diff.clone(),
127 predicted: *predicted,
128 in_open_source_repo: *in_open_source_repo,
129 }
130 })
131 .collect();
132
133 let captured_related_files: Vec<CapturedRelatedFile> = related_files
134 .iter()
135 .map(|rf| CapturedRelatedFile {
136 path: strip_root_name(&rf.path, &root_name).into(),
137 max_row: rf.max_row,
138 excerpts: rf
139 .excerpts
140 .iter()
141 .map(|e| CapturedRelatedExcerpt {
142 row_range: e.row_range.clone(),
143 text: e.text.to_string(),
144 })
145 .collect(),
146 })
147 .collect();
148
149 CapturedPromptInput {
150 cursor_file_content: content,
151 cursor_offset: full_cursor_offset,
152 cursor_row: cursor_point.row,
153 cursor_column: cursor_point.column,
154 excerpt_start_row: Some(0),
155 events: captured_events,
156 related_files: captured_related_files,
157 in_open_source_repo: false,
158 }
159 });
160
161 let mut spec = ExampleSpec {
162 name: generate_timestamp_name(),
163 repository_url,
164 revision,
165 tags: Vec::new(),
166 reasoning: None,
167 uncommitted_diff,
168 cursor_path,
169 cursor_position: String::new(),
170 edit_history,
171 expected_patches,
172 rejected_patch,
173 captured_prompt_input: prompt_input,
174 telemetry: None,
175 human_feedback: Vec::new(),
176 rating: None,
177 };
178 spec.set_cursor_excerpt(
179 &cursor_excerpt,
180 cursor_offset_in_excerpt,
181 &line_comment_prefix,
182 );
183 Ok(spec)
184 }))
185}
186
/// Returns `path` with a leading worktree root directory named `root_name` removed.
///
/// Matching is component-wise (`Path::strip_prefix`), so `project/src` loses its `project`
/// prefix but `projects/src` does not. Paths that don't start with `root_name` are returned
/// unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
190
191fn write_event_with_relative_paths(
192 output: &mut String,
193 event: &zeta_prompt::Event,
194 root_name: &str,
195) {
196 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
197 for component in strip_root_name(path, root_name).components() {
198 output.push('/');
199 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
200 }
201 }
202
203 let zeta_prompt::Event::BufferChange {
204 path,
205 old_path,
206 diff,
207 ..
208 } = event;
209
210 output.push_str("--- a");
211 write_relative_path(output, old_path.as_ref(), root_name);
212 output.push_str("\n+++ b");
213 write_relative_path(output, path.as_ref(), root_name);
214 output.push('\n');
215 output.push_str(diff);
216}
217
218fn compute_cursor_excerpt(
219 snapshot: &language::BufferSnapshot,
220 cursor_anchor: language::Anchor,
221) -> (String, usize, Range<Point>) {
222 use text::ToOffset as _;
223
224 let cursor_point = cursor_anchor.to_point(snapshot);
225 let (_editable_range, context_range) =
226 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
227 let context_start_offset = context_range.start.to_offset(snapshot);
228 let cursor_offset = cursor_anchor.to_offset(snapshot);
229 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
230 let excerpt = snapshot
231 .text_for_range(context_range.clone())
232 .collect::<String>();
233 (excerpt, cursor_offset_in_excerpt, context_range)
234}
235
236async fn collect_snapshots(
237 project: &Entity<Project>,
238 git_store: &Entity<project::git_store::GitStore>,
239 worktree_id: WorktreeId,
240 events: &[StoredEvent],
241 cx: &mut gpui::AsyncApp,
242) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
243 let mut snapshots_by_path = HashMap::default();
244 for stored_event in events {
245 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
246 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
247 let project_path = project
248 .find_project_path(path, cx)
249 .filter(|path| path.worktree_id == worktree_id)?;
250 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
251 Some((project_path, relative_path))
252 }) {
253 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
254 let buffer = project
255 .update(cx, |project, cx| {
256 project.open_buffer(project_path.clone(), cx)
257 })
258 .await?;
259 let diff = git_store
260 .update(cx, |git_store, cx| {
261 git_store.open_uncommitted_diff(buffer.clone(), cx)
262 })
263 .await?;
264 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
265 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
266 }
267 }
268 }
269 Ok(snapshots_by_path)
270}
271
272fn compute_uncommitted_diff(
273 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
274) -> String {
275 let mut uncommitted_diff = String::new();
276 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
277 if let Some(head_text) = &diff_snapshot.base_text_string() {
278 let file_diff = language::unified_diff(head_text, &before_text.text());
279 if !file_diff.is_empty() {
280 let path_str = relative_path.to_string_lossy();
281 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
282 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
283 uncommitted_diff.push_str(&file_diff);
284 if !uncommitted_diff.ends_with('\n') {
285 uncommitted_diff.push('\n');
286 }
287 }
288 }
289 }
290 uncommitted_diff
291}
292
293fn generate_timestamp_name() -> String {
294 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
295 match format {
296 Ok(format) => {
297 let now = time::OffsetDateTime::now_local()
298 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
299 now.format(&format)
300 .unwrap_or_else(|_| "unknown-time".to_string())
301 }
302 Err(_) => "unknown-time".to_string(),
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    // End-to-end check of `capture_example` on a small git-backed project. The fixture layers
    // three kinds of change so each one lands in a different field of the spec:
    // - HEAD vs. on-disk contents (comments 1 and 2) -> `uncommitted_diff`;
    // - buffer edits made after the buffer is registered (comments 3 and 4) -> `edit_history`;
    // - an edit to a file outside the project's worktree -> recorded as an event, but absent
    //   from the captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents committed at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted comments.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer so subsequent edits are recorded as events.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // These edits happen after registration, so they should show up in `edit_history`
        // rather than in `uncommitted_diff`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // `Anchor::MIN` places the cursor at the very start of the file (checked below).
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is a timestamp; pin it so the struct comparison is stable.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // Only the HEAD -> on-disk changes (comments 1 and 2).
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the recorded buffer edits (comments 3 and 4); the external-file event
                // is not included.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // NOTE(review): the hunk header below declares 16 rows while only 15 context
                // lines follow (the excerpt range appears to end on the empty row after the
                // final newline) — confirm whether that count is intended.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                // Checked field-by-field below.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(prompt_input.events.len() > 0, "should have captured events");
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs the globals the test relies on: a test `SettingsStore`, test logging, a
    /// `Client` backed by `FakeHttpClient::with_404_response()`, the language-model
    /// registry, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}