1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position,
4 example_spec::{
5 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
6 ExampleSpec, MAX_CURSOR_FILE_SIZE,
7 },
8};
9use anyhow::Result;
10use buffer_diff::BufferDiffSnapshot;
11use collections::HashMap;
12use feature_flags::FeatureFlagAppExt as _;
13use gpui::{App, Entity, Task};
14use language::{Buffer, ToPoint as _};
15use project::{Project, WorktreeId};
16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
18
/// Default number of edit predictions, out of every 10,000, that are captured
/// as examples when the `example_capture_rate` setting is unset (non-staff).
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;

/// Staff-specific default for the example capture rate, per 10,000 predictions,
/// used when the `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;

/// Rate, per 10,000 calls, at which `should_send_testing_zeta2_request`
/// returns `true`.
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
22
23pub fn capture_example(
24 project: Entity<Project>,
25 buffer: Entity<Buffer>,
26 cursor_anchor: language::Anchor,
27 mut events: Vec<StoredEvent>,
28 related_files: Vec<zeta_prompt::RelatedFile>,
29 populate_expected_patch: bool,
30 cx: &mut App,
31) -> Option<Task<Result<ExampleSpec>>> {
32 let snapshot = buffer.read(cx).snapshot();
33 let file = snapshot.file()?;
34 let worktree_id = file.worktree_id(cx);
35 let repository = project.read(cx).active_repository(cx)?;
36 let repository_snapshot = repository.read(cx).snapshot();
37 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
38 let root_name = worktree.read(cx).root_name_str().to_owned();
39 let cursor_path: Arc<Path> = file.path().as_std_path().into();
40 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
41 return None;
42 }
43
44 let repository_url = repository_snapshot
45 .remote_origin_url
46 .clone()
47 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
48 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
49
50 let git_store = project.read(cx).git_store().clone();
51
52 Some(cx.spawn(async move |mut cx| {
53 let snapshots_by_path =
54 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
55
56 events.retain(|stored_event| {
57 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
58 let relative_path = strip_root_name(path, &root_name);
59 snapshots_by_path.contains_key(relative_path)
60 });
61
62 let line_comment_prefix = snapshot
63 .language()
64 .and_then(|lang| lang.config().line_comments.first())
65 .map(|s| s.to_string())
66 .unwrap_or_default();
67
68 let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
69 let cursor_point = cursor_anchor.to_point(&snapshot);
70 let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
71 Some(snapshot.text())
72 } else {
73 None
74 };
75
76 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
77 .background_executor()
78 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
79 .await;
80 let uncommitted_diff = cx
81 .background_executor()
82 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
83 .await;
84
85 let mut edit_history = String::new();
86 for stored_event in &events {
87 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
88 if !edit_history.ends_with('\n') {
89 edit_history.push('\n');
90 }
91 }
92
93 // Initialize an empty patch with context lines, to make it easy
94 // to write the expected patch by hand.
95 let mut expected_patches = Vec::new();
96 let mut rejected_patch = None;
97 if populate_expected_patch {
98 let mut empty_patch = String::new();
99 let start_row = cursor_excerpt_range.start.row + 1;
100 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
101 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
102 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
103 writeln!(
104 &mut empty_patch,
105 "@@ -{},{} +{},{} @@",
106 start_row, row_count, start_row, row_count,
107 )
108 .ok();
109 for line in cursor_excerpt.lines() {
110 writeln!(&mut empty_patch, " {}", line).ok();
111 }
112
113 expected_patches.push(empty_patch.clone());
114 rejected_patch = Some(empty_patch);
115 }
116
117 let prompt_input = cursor_file_content.map(|content| {
118 let captured_events: Vec<CapturedEvent> = events
119 .iter()
120 .map(|stored_event| {
121 let zeta_prompt::Event::BufferChange {
122 path,
123 old_path,
124 diff,
125 predicted,
126 in_open_source_repo,
127 } = stored_event.event.as_ref();
128 CapturedEvent {
129 path: strip_root_name(path, &root_name).into(),
130 old_path: strip_root_name(old_path, &root_name).into(),
131 diff: diff.clone(),
132 predicted: *predicted,
133 in_open_source_repo: *in_open_source_repo,
134 }
135 })
136 .collect();
137
138 let captured_related_files: Vec<CapturedRelatedFile> = related_files
139 .iter()
140 .map(|rf| CapturedRelatedFile {
141 path: strip_root_name(&rf.path, &root_name).into(),
142 max_row: rf.max_row,
143 excerpts: rf
144 .excerpts
145 .iter()
146 .map(|e| CapturedRelatedExcerpt {
147 row_range: e.row_range.clone(),
148 text: e.text.to_string(),
149 })
150 .collect(),
151 })
152 .collect();
153
154 CapturedPromptInput {
155 cursor_file_content: content,
156 cursor_offset: full_cursor_offset,
157 cursor_row: cursor_point.row,
158 cursor_column: cursor_point.column,
159 excerpt_start_row: Some(0),
160 events: captured_events,
161 related_files: captured_related_files,
162 }
163 });
164
165 let mut spec = ExampleSpec {
166 name: generate_timestamp_name(),
167 repository_url,
168 revision,
169 tags: Vec::new(),
170 reasoning: None,
171 uncommitted_diff,
172 cursor_path,
173 cursor_position: String::new(),
174 edit_history,
175 expected_patches,
176 rejected_patch,
177 captured_prompt_input: prompt_input,
178 telemetry: None,
179 human_feedback: Vec::new(),
180 rating: None,
181 };
182 spec.set_cursor_excerpt(
183 &cursor_excerpt,
184 cursor_offset_in_excerpt,
185 &line_comment_prefix,
186 );
187 Ok(spec)
188 }))
189}
190
/// Returns `path` with the leading `root_name` removed, or `path` unchanged
/// when it does not start with `root_name` (compared component-wise).
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
194
195fn write_event_with_relative_paths(
196 output: &mut String,
197 event: &zeta_prompt::Event,
198 root_name: &str,
199) {
200 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
201 for component in strip_root_name(path, root_name).components() {
202 output.push('/');
203 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
204 }
205 }
206
207 let zeta_prompt::Event::BufferChange {
208 path,
209 old_path,
210 diff,
211 ..
212 } = event;
213
214 output.push_str("--- a");
215 write_relative_path(output, old_path.as_ref(), root_name);
216 output.push_str("\n+++ b");
217 write_relative_path(output, path.as_ref(), root_name);
218 output.push('\n');
219 output.push_str(diff);
220}
221
222fn compute_cursor_excerpt(
223 snapshot: &language::BufferSnapshot,
224 cursor_anchor: language::Anchor,
225) -> (String, usize, Range<Point>) {
226 use text::ToOffset as _;
227
228 let cursor_point = cursor_anchor.to_point(snapshot);
229 let (_editable_range, context_range) =
230 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
231 let context_start_offset = context_range.start.to_offset(snapshot);
232 let cursor_offset = cursor_anchor.to_offset(snapshot);
233 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
234 let excerpt = snapshot
235 .text_for_range(context_range.clone())
236 .collect::<String>();
237 (excerpt, cursor_offset_in_excerpt, context_range)
238}
239
240async fn collect_snapshots(
241 project: &Entity<Project>,
242 git_store: &Entity<project::git_store::GitStore>,
243 worktree_id: WorktreeId,
244 events: &[StoredEvent],
245 cx: &mut gpui::AsyncApp,
246) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
247 let mut snapshots_by_path = HashMap::default();
248 for stored_event in events {
249 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
250 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
251 let project_path = project
252 .find_project_path(path, cx)
253 .filter(|path| path.worktree_id == worktree_id)?;
254 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
255 Some((project_path, relative_path))
256 }) {
257 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
258 let buffer = project
259 .update(cx, |project, cx| {
260 project.open_buffer(project_path.clone(), cx)
261 })
262 .await?;
263 let diff = git_store
264 .update(cx, |git_store, cx| {
265 git_store.open_uncommitted_diff(buffer.clone(), cx)
266 })
267 .await?;
268 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
269 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
270 }
271 }
272 }
273 Ok(snapshots_by_path)
274}
275
276fn compute_uncommitted_diff(
277 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
278) -> String {
279 let mut uncommitted_diff = String::new();
280 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
281 if let Some(head_text) = &diff_snapshot.base_text_string() {
282 let file_diff = language::unified_diff(head_text, &before_text.text());
283 if !file_diff.is_empty() {
284 let path_str = relative_path.to_string_lossy();
285 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
286 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
287 uncommitted_diff.push_str(&file_diff);
288 if !uncommitted_diff.ends_with('\n') {
289 uncommitted_diff.push('\n');
290 }
291 }
292 }
293 }
294 uncommitted_diff
295}
296
297fn generate_timestamp_name() -> String {
298 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
299 match format {
300 Ok(format) => {
301 let now = time::OffsetDateTime::now_local()
302 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
303 now.format(&format)
304 .unwrap_or_else(|_| "unknown-time".to_string())
305 }
306 Err(_) => "unknown-time".to_string(),
307 }
308}
309
310pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
311 let default_rate = if cx.is_staff() {
312 DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
313 } else {
314 DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
315 };
316 let capture_rate = language::language_settings::all_language_settings(None, cx)
317 .edit_predictions
318 .example_capture_rate
319 .unwrap_or(default_rate);
320 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
321 && rand::random::<u16>() % 10_000 < capture_rate
322}
323
324pub(crate) fn should_send_testing_zeta2_request() -> bool {
325 rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// Captures an example from `src/main.rs`, whose HEAD, on-disk, and
    /// in-buffer contents all differ. Checks every field of the resulting
    /// `ExampleSpec`, then spot-checks the captured prompt input.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents committed at HEAD (installed via `set_head_for_repo` below).
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus "comment 1" and "comment 2". This
        // difference is what `uncommitted_diff` is expected to contain.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer so its subsequent edits are recorded as events.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Unsaved buffer edits; these should appear in `edit_history`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture at the very start of the buffer, with no related files and
        // with the expected/rejected patches pre-filled.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is timestamp-based; pin it for comparison.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD -> on-disk text: only comments 1 and 2, not the unsaved edits.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the `src/main.rs` event remains. The `/external/external.rs`
                // event is outside the project worktree and is dropped.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // NOTE(review): the header declares 16 rows but the hunk lists 15
                // lines. The 16th row is the empty row after the trailing newline.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                // Checked field-by-field below.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(prompt_input.events.len() > 0, "should have captured events");
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Sets up the globals the test relies on:
    /// - test settings;
    /// - test logging;
    /// - a `Client` backed by a fake HTTP client that answers with 404;
    /// - `language_model`;
    /// - the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}