1use crate::{
2 StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position,
4 example_spec::{
5 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
6 ExampleSpec, MAX_CURSOR_FILE_SIZE,
7 },
8};
9use anyhow::Result;
10use buffer_diff::BufferDiffSnapshot;
11use collections::HashMap;
12use gpui::{App, Entity, Task};
13use language::{Buffer, ToPoint as _};
14use project::{Project, WorktreeId};
15use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
16use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
17
/// Sampling rate for `should_send_testing_zeta2_request`, expressed as the
/// number of calls out of every 10,000 that should return `true`
/// (500 out of 10,000 is 5%).
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
19
20pub fn capture_example(
21 project: Entity<Project>,
22 buffer: Entity<Buffer>,
23 cursor_anchor: language::Anchor,
24 mut events: Vec<StoredEvent>,
25 related_files: Vec<zeta_prompt::RelatedFile>,
26 populate_expected_patch: bool,
27 cx: &mut App,
28) -> Option<Task<Result<ExampleSpec>>> {
29 let snapshot = buffer.read(cx).snapshot();
30 let file = snapshot.file()?;
31 let worktree_id = file.worktree_id(cx);
32 let repository = project.read(cx).active_repository(cx)?;
33 let repository_snapshot = repository.read(cx).snapshot();
34 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
35 let root_name = worktree.read(cx).root_name_str().to_owned();
36 let cursor_path: Arc<Path> = file.path().as_std_path().into();
37 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
38 return None;
39 }
40
41 let repository_url = repository_snapshot
42 .remote_origin_url
43 .clone()
44 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
45 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
46
47 let git_store = project.read(cx).git_store().clone();
48
49 Some(cx.spawn(async move |mut cx| {
50 let snapshots_by_path =
51 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
52
53 events.retain(|stored_event| {
54 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
55 let relative_path = strip_root_name(path, &root_name);
56 snapshots_by_path.contains_key(relative_path)
57 });
58
59 let line_comment_prefix = snapshot
60 .language()
61 .and_then(|lang| lang.config().line_comments.first())
62 .map(|s| s.to_string())
63 .unwrap_or_default();
64
65 let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
66 let cursor_point = cursor_anchor.to_point(&snapshot);
67 let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
68 Some(snapshot.text())
69 } else {
70 None
71 };
72
73 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
74 .background_executor()
75 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
76 .await;
77 let uncommitted_diff = cx
78 .background_executor()
79 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
80 .await;
81
82 let mut edit_history = String::new();
83 for stored_event in &events {
84 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
85 if !edit_history.ends_with('\n') {
86 edit_history.push('\n');
87 }
88 }
89
90 // Initialize an empty patch with context lines, to make it easy
91 // to write the expected patch by hand.
92 let mut expected_patches = Vec::new();
93 let mut rejected_patch = None;
94 if populate_expected_patch {
95 let mut empty_patch = String::new();
96 let start_row = cursor_excerpt_range.start.row + 1;
97 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
98 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
99 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
100 writeln!(
101 &mut empty_patch,
102 "@@ -{},{} +{},{} @@",
103 start_row, row_count, start_row, row_count,
104 )
105 .ok();
106 for line in cursor_excerpt.lines() {
107 writeln!(&mut empty_patch, " {}", line).ok();
108 }
109
110 expected_patches.push(empty_patch.clone());
111 rejected_patch = Some(empty_patch);
112 }
113
114 let prompt_input = cursor_file_content.map(|content| {
115 let captured_events: Vec<CapturedEvent> = events
116 .iter()
117 .map(|stored_event| {
118 let zeta_prompt::Event::BufferChange {
119 path,
120 old_path,
121 diff,
122 predicted,
123 in_open_source_repo,
124 } = stored_event.event.as_ref();
125 CapturedEvent {
126 path: strip_root_name(path, &root_name).into(),
127 old_path: strip_root_name(old_path, &root_name).into(),
128 diff: diff.clone(),
129 predicted: *predicted,
130 in_open_source_repo: *in_open_source_repo,
131 }
132 })
133 .collect();
134
135 let captured_related_files: Vec<CapturedRelatedFile> = related_files
136 .iter()
137 .map(|rf| CapturedRelatedFile {
138 path: strip_root_name(&rf.path, &root_name).into(),
139 max_row: rf.max_row,
140 excerpts: rf
141 .excerpts
142 .iter()
143 .map(|e| CapturedRelatedExcerpt {
144 row_range: e.row_range.clone(),
145 text: e.text.to_string(),
146 })
147 .collect(),
148 })
149 .collect();
150
151 CapturedPromptInput {
152 cursor_file_content: content,
153 cursor_offset: full_cursor_offset,
154 cursor_row: cursor_point.row,
155 cursor_column: cursor_point.column,
156 excerpt_start_row: Some(0),
157 events: captured_events,
158 related_files: captured_related_files,
159 }
160 });
161
162 let mut spec = ExampleSpec {
163 name: generate_timestamp_name(),
164 repository_url,
165 revision,
166 tags: Vec::new(),
167 reasoning: None,
168 uncommitted_diff,
169 cursor_path,
170 cursor_position: String::new(),
171 edit_history,
172 expected_patches,
173 rejected_patch,
174 captured_prompt_input: prompt_input,
175 telemetry: None,
176 human_feedback: Vec::new(),
177 rating: None,
178 };
179 spec.set_cursor_excerpt(
180 &cursor_excerpt,
181 cursor_offset_in_excerpt,
182 &line_comment_prefix,
183 );
184 Ok(spec)
185 }))
186}
187
/// Removes the worktree's root directory name from the front of `path`.
///
/// Matching is per path component. A path whose first component is not
/// exactly `root_name` is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
191
192fn write_event_with_relative_paths(
193 output: &mut String,
194 event: &zeta_prompt::Event,
195 root_name: &str,
196) {
197 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
198 for component in strip_root_name(path, root_name).components() {
199 output.push('/');
200 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
201 }
202 }
203
204 let zeta_prompt::Event::BufferChange {
205 path,
206 old_path,
207 diff,
208 ..
209 } = event;
210
211 output.push_str("--- a");
212 write_relative_path(output, old_path.as_ref(), root_name);
213 output.push_str("\n+++ b");
214 write_relative_path(output, path.as_ref(), root_name);
215 output.push('\n');
216 output.push_str(diff);
217}
218
219fn compute_cursor_excerpt(
220 snapshot: &language::BufferSnapshot,
221 cursor_anchor: language::Anchor,
222) -> (String, usize, Range<Point>) {
223 use text::ToOffset as _;
224
225 let cursor_point = cursor_anchor.to_point(snapshot);
226 let (_editable_range, context_range) =
227 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
228 let context_start_offset = context_range.start.to_offset(snapshot);
229 let cursor_offset = cursor_anchor.to_offset(snapshot);
230 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
231 let excerpt = snapshot
232 .text_for_range(context_range.clone())
233 .collect::<String>();
234 (excerpt, cursor_offset_in_excerpt, context_range)
235}
236
237async fn collect_snapshots(
238 project: &Entity<Project>,
239 git_store: &Entity<project::git_store::GitStore>,
240 worktree_id: WorktreeId,
241 events: &[StoredEvent],
242 cx: &mut gpui::AsyncApp,
243) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
244 let mut snapshots_by_path = HashMap::default();
245 for stored_event in events {
246 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
247 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
248 let project_path = project
249 .find_project_path(path, cx)
250 .filter(|path| path.worktree_id == worktree_id)?;
251 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
252 Some((project_path, relative_path))
253 }) {
254 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
255 let buffer = project
256 .update(cx, |project, cx| {
257 project.open_buffer(project_path.clone(), cx)
258 })
259 .await?;
260 let diff = git_store
261 .update(cx, |git_store, cx| {
262 git_store.open_uncommitted_diff(buffer.clone(), cx)
263 })
264 .await?;
265 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
266 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
267 }
268 }
269 }
270 Ok(snapshots_by_path)
271}
272
273fn compute_uncommitted_diff(
274 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
275) -> String {
276 let mut uncommitted_diff = String::new();
277 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
278 if let Some(head_text) = &diff_snapshot.base_text_string() {
279 let file_diff = language::unified_diff(head_text, &before_text.text());
280 if !file_diff.is_empty() {
281 let path_str = relative_path.to_string_lossy();
282 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
283 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
284 uncommitted_diff.push_str(&file_diff);
285 if !uncommitted_diff.ends_with('\n') {
286 uncommitted_diff.push('\n');
287 }
288 }
289 }
290 }
291 uncommitted_diff
292}
293
294fn generate_timestamp_name() -> String {
295 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
296 match format {
297 Ok(format) => {
298 let now = time::OffsetDateTime::now_local()
299 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
300 now.format(&format)
301 .unwrap_or_else(|_| "unknown-time".to_string())
302 }
303 Err(_) => "unknown-time".to_string(),
304 }
305}
306
307pub(crate) fn should_send_testing_zeta2_request() -> bool {
308 rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
309}
310
// Tests for example capture. They run `capture_example` against a fake
// filesystem, a fake git repository and the global `EditPredictionStore`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end test for `capture_example`.
    ///
    /// It sets up:
    /// - a fake repository whose HEAD content differs from the file on disk,
    /// - two edits in the cursor buffer,
    /// - one edit in a file outside the project.
    ///
    /// It then checks that the captured spec contains:
    /// - the HEAD-vs-disk changes as `uncommitted_diff`,
    /// - only the in-project edits in `edit_history` (the external edit is absent),
    /// - a context-only patch as both `expected_patches` and `rejected_patch`,
    /// - prompt input consistent with a cursor at `Anchor::MIN`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents recorded as the repository's HEAD commit.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two comment lines. These two lines are
        // what the expected `uncommitted_diff` below contains.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the edit prediction store so that its edits
        // appear in `edit_history_for_project` below.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Two edits that are not saved to disk. They are expected to make up
        // the example's `edit_history`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer and the
        // placeholder patches enabled.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is time-based; pin it so the comparison is stable.
        example.name = "test".to_string();

        // `captured_prompt_input` is copied from the result here and checked
        // field by field after this assertion.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );

        // The file is small, so the raw prompt input must have been captured.
        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(prompt_input.events.len() > 0, "should have captured events");
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs the globals the capture test needs:
    /// - a test `SettingsStore`,
    /// - test logging,
    /// - language models backed by a client whose fake HTTP client answers 404,
    /// - the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}