1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position,
4 example_spec::{
5 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
6 ExampleSpec, MAX_CURSOR_FILE_SIZE,
7 },
8};
9use anyhow::Result;
10use buffer_diff::BufferDiffSnapshot;
11use collections::HashMap;
12use feature_flags::FeatureFlagAppExt as _;
13use gpui::{App, Entity, Task};
14use language::{Buffer, ToPoint as _};
15use project::{Project, WorktreeId};
16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
18
/// Default number of edit predictions, out of every 10 000, sampled for example capture for
/// non-staff users when the `example_capture_rate` setting is not set.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Default number of edit predictions, out of every 10 000, sampled for example capture for
/// staff users when the `example_capture_rate` setting is not set.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
21
/// Captures the edit-prediction context around `cursor_anchor` in `buffer` as an
/// [`ExampleSpec`].
///
/// Returns `None` without spawning anything when the example cannot be tied to a git
/// revision. This happens when:
/// - the buffer has no file;
/// - the project has no active repository;
/// - the buffer's worktree cannot be found, or its absolute path differs from the
///   repository's working directory;
/// - the repository has neither a `remote_origin_url` nor a `remote_upstream_url`;
/// - the repository has no HEAD commit.
///
/// Otherwise returns a task that builds the spec:
/// - `uncommitted_diff`: for every file touched by `events` inside the cursor's worktree, the
///   diff between the diff snapshot's base text and that file's text before its first
///   recorded event.
/// - `edit_history`: the retained `events` rendered as diffs with worktree-relative paths.
///   Events for files that had no snapshot collected (e.g. files outside the worktree) are
///   dropped.
/// - `cursor_position`: the excerpt around the cursor, set via `set_cursor_excerpt`.
/// - `expected_patches` / `rejected_patch`: when `populate_expected_patch` is true, a no-op
///   patch containing the excerpt as context lines, meant to be edited by hand. Otherwise
///   these are empty / `None`.
/// - `captured_prompt_input`: the raw cursor file text, cursor offset/row/column, events, and
///   `related_files`. Only recorded when `snapshot.len()` is at most `MAX_CURSOR_FILE_SIZE`.
///
/// The task fails if opening a touched buffer or its uncommitted diff fails.
pub fn capture_example(
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    cursor_anchor: language::Anchor,
    mut events: Vec<StoredEvent>,
    related_files: Vec<zeta_prompt::RelatedFile>,
    populate_expected_patch: bool,
    cx: &mut App,
) -> Option<Task<Result<ExampleSpec>>> {
    let snapshot = buffer.read(cx).snapshot();
    let file = snapshot.file()?;
    let worktree_id = file.worktree_id(cx);
    let repository = project.read(cx).active_repository(cx)?;
    let repository_snapshot = repository.read(cx).snapshot();
    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
    let root_name = worktree.read(cx).root_name_str().to_owned();
    let cursor_path: Arc<Path> = file.path().as_std_path().into();
    // Only capture when the worktree root is the repository's working directory, so that
    // worktree-relative paths are also repository-relative.
    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
        return None;
    }

    // Prefer the origin remote URL, falling back to the upstream one.
    let repository_url = repository_snapshot
        .remote_origin_url
        .clone()
        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();

    let git_store = project.read(cx).git_store().clone();

    Some(cx.spawn(async move |mut cx| {
        // Pre-edit text and uncommitted-diff snapshot for each in-worktree file touched by
        // `events`, keyed by worktree-relative path.
        let snapshots_by_path =
            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;

        // Keep only events whose file produced a snapshot above.
        events.retain(|stored_event| {
            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
            let relative_path = strip_root_name(path, &root_name);
            snapshots_by_path.contains_key(relative_path)
        });

        // First line-comment token of the buffer's language, or "" when there is none.
        let line_comment_prefix = snapshot
            .language()
            .and_then(|lang| lang.config().line_comments.first())
            .map(|s| s.to_string())
            .unwrap_or_default();

        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
        let cursor_point = cursor_anchor.to_point(&snapshot);
        // The whole file is only kept for the prompt input when it is small enough.
        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
            Some(snapshot.text())
        } else {
            None
        };

        // `snapshot` and `snapshots_by_path` are moved into these background tasks, so all
        // other uses of them must come before this point.
        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
            .background_executor()
            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
            .await;
        let uncommitted_diff = cx
            .background_executor()
            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
            .await;

        // Concatenate the retained events, making sure each one ends with a newline.
        let mut edit_history = String::new();
        for stored_event in &events {
            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
            if !edit_history.ends_with('\n') {
                edit_history.push('\n');
            }
        }

        // Initialize an empty patch with context lines, to make it easy
        // to write the expected patch by hand.
        let mut expected_patches = Vec::new();
        let mut rejected_patch = None;
        if populate_expected_patch {
            let mut empty_patch = String::new();
            // Hunk headers use 1-based line numbers.
            let start_row = cursor_excerpt_range.start.row + 1;
            // NOTE(review): this counts the range's end row even when the excerpt ends at
            // column 0 of that row (e.g. after a trailing newline). `lines()` below then yields
            // one line fewer than the header claims. The test in this file pins
            // `@@ -1,16 +1,16 @@` with 15 context lines — confirm whether that is intended.
            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
            writeln!(
                &mut empty_patch,
                "@@ -{},{} +{},{} @@",
                start_row, row_count, start_row, row_count,
            )
            .ok();
            // Every excerpt line becomes an unchanged (space-prefixed) context line.
            for line in cursor_excerpt.lines() {
                writeln!(&mut empty_patch, " {}", line).ok();
            }

            expected_patches.push(empty_patch.clone());
            rejected_patch = Some(empty_patch);
        }

        // Raw prompt inputs are only recorded when the full cursor file was captured above.
        let prompt_input = cursor_file_content.map(|content| {
            let captured_events: Vec<CapturedEvent> = events
                .iter()
                .map(|stored_event| {
                    let zeta_prompt::Event::BufferChange {
                        path,
                        old_path,
                        diff,
                        predicted,
                        in_open_source_repo,
                    } = stored_event.event.as_ref();
                    CapturedEvent {
                        path: strip_root_name(path, &root_name).into(),
                        old_path: strip_root_name(old_path, &root_name).into(),
                        diff: diff.clone(),
                        predicted: *predicted,
                        in_open_source_repo: *in_open_source_repo,
                    }
                })
                .collect();

            let captured_related_files: Vec<CapturedRelatedFile> = related_files
                .iter()
                .map(|rf| CapturedRelatedFile {
                    path: strip_root_name(&rf.path, &root_name).into(),
                    max_row: rf.max_row,
                    excerpts: rf
                        .excerpts
                        .iter()
                        .map(|e| CapturedRelatedExcerpt {
                            row_range: e.row_range.clone(),
                            text: e.text.to_string(),
                        })
                        .collect(),
                })
                .collect();

            CapturedPromptInput {
                cursor_file_content: content,
                cursor_offset: full_cursor_offset,
                cursor_row: cursor_point.row,
                cursor_column: cursor_point.column,
                events: captured_events,
                related_files: captured_related_files,
            }
        });

        // `cursor_position` is filled in by `set_cursor_excerpt` below.
        let mut spec = ExampleSpec {
            name: generate_timestamp_name(),
            repository_url,
            revision,
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff,
            cursor_path,
            cursor_position: String::new(),
            edit_history,
            expected_patches,
            rejected_patch,
            captured_prompt_input: prompt_input,
        };
        spec.set_cursor_excerpt(
            &cursor_excerpt,
            cursor_offset_in_excerpt,
            &line_comment_prefix,
        );
        Ok(spec)
    }))
}
185
/// Removes a leading `root_name` component from `path`.
///
/// Returns `path` unchanged when it does not begin with `root_name`. Matching is done per
/// path component (via [`Path::strip_prefix`]), so `project-two/a.rs` is not stripped by a
/// root name of `project`.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
189
190fn write_event_with_relative_paths(
191 output: &mut String,
192 event: &zeta_prompt::Event,
193 root_name: &str,
194) {
195 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
196 for component in strip_root_name(path, root_name).components() {
197 output.push('/');
198 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
199 }
200 }
201
202 let zeta_prompt::Event::BufferChange {
203 path,
204 old_path,
205 diff,
206 ..
207 } = event;
208
209 output.push_str("--- a");
210 write_relative_path(output, old_path.as_ref(), root_name);
211 output.push_str("\n+++ b");
212 write_relative_path(output, path.as_ref(), root_name);
213 output.push('\n');
214 output.push_str(diff);
215}
216
217fn compute_cursor_excerpt(
218 snapshot: &language::BufferSnapshot,
219 cursor_anchor: language::Anchor,
220) -> (String, usize, Range<Point>) {
221 use text::ToOffset as _;
222
223 let cursor_point = cursor_anchor.to_point(snapshot);
224 let (_editable_range, context_range) =
225 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
226 let context_start_offset = context_range.start.to_offset(snapshot);
227 let cursor_offset = cursor_anchor.to_offset(snapshot);
228 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
229 let excerpt = snapshot
230 .text_for_range(context_range.clone())
231 .collect::<String>();
232 (excerpt, cursor_offset_in_excerpt, context_range)
233}
234
235async fn collect_snapshots(
236 project: &Entity<Project>,
237 git_store: &Entity<project::git_store::GitStore>,
238 worktree_id: WorktreeId,
239 events: &[StoredEvent],
240 cx: &mut gpui::AsyncApp,
241) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
242 let mut snapshots_by_path = HashMap::default();
243 for stored_event in events {
244 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
245 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
246 let project_path = project
247 .find_project_path(path, cx)
248 .filter(|path| path.worktree_id == worktree_id)?;
249 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
250 Some((project_path, relative_path))
251 }) {
252 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
253 let buffer = project
254 .update(cx, |project, cx| {
255 project.open_buffer(project_path.clone(), cx)
256 })
257 .await?;
258 let diff = git_store
259 .update(cx, |git_store, cx| {
260 git_store.open_uncommitted_diff(buffer.clone(), cx)
261 })
262 .await?;
263 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
264 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
265 }
266 }
267 }
268 Ok(snapshots_by_path)
269}
270
271fn compute_uncommitted_diff(
272 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
273) -> String {
274 let mut uncommitted_diff = String::new();
275 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
276 if let Some(head_text) = &diff_snapshot.base_text_string() {
277 let file_diff = language::unified_diff(head_text, &before_text.text());
278 if !file_diff.is_empty() {
279 let path_str = relative_path.to_string_lossy();
280 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
281 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
282 uncommitted_diff.push_str(&file_diff);
283 if !uncommitted_diff.ends_with('\n') {
284 uncommitted_diff.push('\n');
285 }
286 }
287 }
288 }
289 uncommitted_diff
290}
291
292fn generate_timestamp_name() -> String {
293 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
294 match format {
295 Ok(format) => {
296 let now = time::OffsetDateTime::now_local()
297 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
298 now.format(&format)
299 .unwrap_or_else(|_| "unknown-time".to_string())
300 }
301 Err(_) => "unknown-time".to_string(),
302 }
303}
304
305pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
306 let default_rate = if cx.is_staff() {
307 DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
308 } else {
309 DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
310 };
311 let capture_rate = language::language_settings::all_language_settings(None, cx)
312 .edit_predictions
313 .example_capture_rate
314 .unwrap_or(default_rate);
315 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
316 && rand::random::<u16>() % 10_000 < capture_rate
317}
318
/// Tests for capturing edit-prediction examples.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`. The fixture is set up so that:
    /// - HEAD holds `committed_contents`, and the file on disk adds `// comment 1` and
    ///   `// comment 2`. Those additions are expected in `uncommitted_diff`.
    /// - Buffer edits add `// comment 3` and `// comment 4`. Those are expected in
    ///   `edit_history`.
    /// - An edit to a file outside the project worktree is recorded by the store but must
    ///   not appear in the captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at HEAD.
        let committed_contents = indoc! {"
 fn main() {
 one();
 two();
 three();
 four();
 five();
 six();
 seven();
 eight();
 nine();
 }
 "};

        // On-disk contents: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
 fn main() {
 // comment 1
 one();
 two();
 three();
 four();
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer so its edits are recorded as edit-history events.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Unsaved buffer edits; these should surface in `edit_history`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, " // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, " // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
 fn main() {
 // comment 1
 one();
 two();
 // comment 4
 three();
 four();
 // comment 3
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The real name is a timestamp; pin it so the spec can be compared exactly.
        example.name = "test".to_string();

        // The external edit must be absent from `edit_history`, and the uncommitted diff
        // must only cover the on-disk (comment 1 / comment 2) changes.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
 --- a/src/main.rs
 +++ b/src/main.rs
 @@ -1,4 +1,5 @@
 fn main() {
 + // comment 1
 one();
 two();
 three();
 @@ -7,5 +8,6 @@
 six();
 seven();
 eight();
 + // comment 2
 nine();
 }
 "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
 fn main() {
 ^[CURSOR_POSITION]
 // comment 1
 one();
 two();
 // comment 4
 three();
 four();
 // comment 3
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "}
                .to_string(),
                edit_history: indoc! {"
 --- a/src/main.rs
 +++ b/src/main.rs
 @@ -2,8 +2,10 @@
 // comment 1
 one();
 two();
 + // comment 4
 three();
 four();
 + // comment 3
 five();
 six();
 seven();
 "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
 --- a/src/main.rs
 +++ b/src/main.rs
 @@ -1,16 +1,16 @@
 fn main() {
 // comment 1
 one();
 two();
 // comment 4
 three();
 four();
 // comment 3
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
 --- a/src/main.rs
 +++ b/src/main.rs
 @@ -1,16 +1,16 @@
 fn main() {
 // comment 1
 one();
 two();
 // comment 4
 three();
 four();
 // comment 3
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "}
                    .to_string()
                ),
                // Checked field-by-field below instead.
                captured_prompt_input: example.captured_prompt_input.clone(),
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(prompt_input.events.len() > 0, "should have captured events");
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs a test settings store, test logging, a client backed by a fake HTTP client
    /// that answers 404, the language-model registry, a user store, and the global
    /// `EditPredictionStore` used by the test above.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}