1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position,
4 example_spec::{
5 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
6 ExampleSpec, MAX_CURSOR_FILE_SIZE,
7 },
8};
9use anyhow::Result;
10use buffer_diff::BufferDiffSnapshot;
11use collections::HashMap;
12use feature_flags::FeatureFlagAppExt as _;
13use gpui::{App, Entity, Task};
14use language::{Buffer, ToPoint as _};
15use project::{Project, WorktreeId};
16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
18
/// Default example-capture rate, in captures per 10,000 edit predictions (10 → 0.1%),
/// used when the `example_capture_rate` setting is unset and the user is not staff.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Default example-capture rate for staff, in captures per 10,000 edit predictions (100 → 1%).
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
/// Rate, per 10,000 predictions, at which `should_send_testing_zeta2_request` returns `true`
/// (500 → 5%).
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
22
/// Captures an edit-prediction example ([`ExampleSpec`]) for `cursor_anchor` in `buffer`.
///
/// Returns `None` without spawning anything when the example cannot be captured: the buffer
/// has no file, the project has no active repository, the buffer's worktree cannot be found or
/// is not the repository's working directory, the repository has neither a remote origin URL
/// nor a remote upstream URL, or it has no HEAD commit. Otherwise returns a task resolving to
/// the captured spec.
///
/// * `events` - recent edit events. Events whose files are not resolved in the cursor's
///   worktree by `collect_snapshots` are dropped before anything is recorded.
/// * `related_files` - copied (with root-stripped paths) into the captured prompt input.
/// * `populate_expected_patch` - when `true`, `expected_patches` and `rejected_patch` are
///   seeded with a no-op patch covering the cursor excerpt, to be edited by hand.
///
/// # Errors
/// The task fails if `collect_snapshots` fails to open a touched buffer or its uncommitted diff.
pub fn capture_example(
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    cursor_anchor: language::Anchor,
    mut events: Vec<StoredEvent>,
    related_files: Vec<zeta_prompt::RelatedFile>,
    populate_expected_patch: bool,
    cx: &mut App,
) -> Option<Task<Result<ExampleSpec>>> {
    let snapshot = buffer.read(cx).snapshot();
    let file = snapshot.file()?;
    let worktree_id = file.worktree_id(cx);
    let repository = project.read(cx).active_repository(cx)?;
    let repository_snapshot = repository.read(cx).snapshot();
    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
    let root_name = worktree.read(cx).root_name_str().to_owned();
    // Worktree-relative path of the file containing the cursor.
    let cursor_path: Arc<Path> = file.path().as_std_path().into();
    // Bail unless the worktree root is exactly the repository's working directory
    // (presumably so that worktree-relative paths double as repository-relative ones).
    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
        return None;
    }

    // Prefer the origin remote URL, falling back to the upstream remote URL.
    let repository_url = repository_snapshot
        .remote_origin_url
        .clone()
        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();

    let git_store = project.read(cx).git_store().clone();

    Some(cx.spawn(async move |mut cx| {
        let snapshots_by_path =
            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;

        // Keep only events whose root-stripped path had a snapshot collected, i.e. files in
        // the cursor's worktree; this drops edits to files outside it.
        events.retain(|stored_event| {
            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
            let relative_path = strip_root_name(path, &root_name);
            snapshots_by_path.contains_key(relative_path)
        });

        // First line-comment prefix of the buffer's language, or "" when there is none.
        let line_comment_prefix = snapshot
            .language()
            .and_then(|lang| lang.config().line_comments.first())
            .map(|s| s.to_string())
            .unwrap_or_default();

        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
        let cursor_point = cursor_anchor.to_point(&snapshot);
        // The whole file is only embedded when it is at most `MAX_CURSOR_FILE_SIZE` long.
        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
            Some(snapshot.text())
        } else {
            None
        };

        // `snapshot` is moved into this background task, so everything derived from it on
        // this side has to be computed above.
        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
            .background_executor()
            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
            .await;
        let uncommitted_diff = cx
            .background_executor()
            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
            .await;

        // Render the retained events as diffs, each terminated by a newline.
        let mut edit_history = String::new();
        for stored_event in &events {
            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
            if !edit_history.ends_with('\n') {
                edit_history.push('\n');
            }
        }

        // Initialize an empty patch with context lines, to make it easy
        // to write the expected patch by hand.
        let mut expected_patches = Vec::new();
        let mut rejected_patch = None;
        if populate_expected_patch {
            let mut empty_patch = String::new();
            // Hunk headers use 1-based line numbers.
            let start_row = cursor_excerpt_range.start.row + 1;
            // NOTE(review): this counts the end row even when the excerpt range ends at
            // column 0, in which case `cursor_excerpt.lines()` below writes one line fewer
            // than the header announces. Confirm whether consumers rely on this count.
            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
            writeln!(
                &mut empty_patch,
                "@@ -{},{} +{},{} @@",
                start_row, row_count, start_row, row_count,
            )
            .ok();
            // Every excerpt line becomes an unchanged context line.
            for line in cursor_excerpt.lines() {
                writeln!(&mut empty_patch, " {}", line).ok();
            }

            expected_patches.push(empty_patch.clone());
            rejected_patch = Some(empty_patch);
        }

        // The prompt input is only recorded when the full cursor file content was captured.
        let prompt_input = cursor_file_content.map(|content| {
            let captured_events: Vec<CapturedEvent> = events
                .iter()
                .map(|stored_event| {
                    let zeta_prompt::Event::BufferChange {
                        path,
                        old_path,
                        diff,
                        predicted,
                        in_open_source_repo,
                    } = stored_event.event.as_ref();
                    CapturedEvent {
                        path: strip_root_name(path, &root_name).into(),
                        old_path: strip_root_name(old_path, &root_name).into(),
                        diff: diff.clone(),
                        predicted: *predicted,
                        in_open_source_repo: *in_open_source_repo,
                    }
                })
                .collect();

            let captured_related_files: Vec<CapturedRelatedFile> = related_files
                .iter()
                .map(|rf| CapturedRelatedFile {
                    path: strip_root_name(&rf.path, &root_name).into(),
                    max_row: rf.max_row,
                    excerpts: rf
                        .excerpts
                        .iter()
                        .map(|e| CapturedRelatedExcerpt {
                            row_range: e.row_range.clone(),
                            text: e.text.to_string(),
                        })
                        .collect(),
                })
                .collect();

            CapturedPromptInput {
                cursor_file_content: content,
                cursor_offset: full_cursor_offset,
                cursor_row: cursor_point.row,
                cursor_column: cursor_point.column,
                events: captured_events,
                related_files: captured_related_files,
            }
        });

        // `cursor_position` starts empty and is filled in by `set_cursor_excerpt` below.
        let mut spec = ExampleSpec {
            name: generate_timestamp_name(),
            repository_url,
            revision,
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff,
            cursor_path,
            cursor_position: String::new(),
            edit_history,
            expected_patches,
            rejected_patch,
            captured_prompt_input: prompt_input,
            telemetry: None,
            human_feedback: Vec::new(),
            rating: None,
        };
        spec.set_cursor_excerpt(
            &cursor_excerpt,
            cursor_offset_in_excerpt,
            &line_comment_prefix,
        );
        Ok(spec)
    }))
}
189
/// Returns `path` relative to the worktree root name `root_name`.
///
/// The prefix is matched component-wise (`Path::strip_prefix`), so `root_name` is removed only
/// when it is the whole leading component; any other path is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
193
194fn write_event_with_relative_paths(
195 output: &mut String,
196 event: &zeta_prompt::Event,
197 root_name: &str,
198) {
199 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
200 for component in strip_root_name(path, root_name).components() {
201 output.push('/');
202 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
203 }
204 }
205
206 let zeta_prompt::Event::BufferChange {
207 path,
208 old_path,
209 diff,
210 ..
211 } = event;
212
213 output.push_str("--- a");
214 write_relative_path(output, old_path.as_ref(), root_name);
215 output.push_str("\n+++ b");
216 write_relative_path(output, path.as_ref(), root_name);
217 output.push('\n');
218 output.push_str(diff);
219}
220
221fn compute_cursor_excerpt(
222 snapshot: &language::BufferSnapshot,
223 cursor_anchor: language::Anchor,
224) -> (String, usize, Range<Point>) {
225 use text::ToOffset as _;
226
227 let cursor_point = cursor_anchor.to_point(snapshot);
228 let (_editable_range, context_range) =
229 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
230 let context_start_offset = context_range.start.to_offset(snapshot);
231 let cursor_offset = cursor_anchor.to_offset(snapshot);
232 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
233 let excerpt = snapshot
234 .text_for_range(context_range.clone())
235 .collect::<String>();
236 (excerpt, cursor_offset_in_excerpt, context_range)
237}
238
239async fn collect_snapshots(
240 project: &Entity<Project>,
241 git_store: &Entity<project::git_store::GitStore>,
242 worktree_id: WorktreeId,
243 events: &[StoredEvent],
244 cx: &mut gpui::AsyncApp,
245) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
246 let mut snapshots_by_path = HashMap::default();
247 for stored_event in events {
248 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
249 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
250 let project_path = project
251 .find_project_path(path, cx)
252 .filter(|path| path.worktree_id == worktree_id)?;
253 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
254 Some((project_path, relative_path))
255 }) {
256 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
257 let buffer = project
258 .update(cx, |project, cx| {
259 project.open_buffer(project_path.clone(), cx)
260 })
261 .await?;
262 let diff = git_store
263 .update(cx, |git_store, cx| {
264 git_store.open_uncommitted_diff(buffer.clone(), cx)
265 })
266 .await?;
267 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
268 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
269 }
270 }
271 }
272 Ok(snapshots_by_path)
273}
274
275fn compute_uncommitted_diff(
276 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
277) -> String {
278 let mut uncommitted_diff = String::new();
279 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
280 if let Some(head_text) = &diff_snapshot.base_text_string() {
281 let file_diff = language::unified_diff(head_text, &before_text.text());
282 if !file_diff.is_empty() {
283 let path_str = relative_path.to_string_lossy();
284 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
285 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
286 uncommitted_diff.push_str(&file_diff);
287 if !uncommitted_diff.ends_with('\n') {
288 uncommitted_diff.push('\n');
289 }
290 }
291 }
292 }
293 uncommitted_diff
294}
295
296fn generate_timestamp_name() -> String {
297 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
298 match format {
299 Ok(format) => {
300 let now = time::OffsetDateTime::now_local()
301 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
302 now.format(&format)
303 .unwrap_or_else(|_| "unknown-time".to_string())
304 }
305 Err(_) => "unknown-time".to_string(),
306 }
307}
308
309pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
310 let default_rate = if cx.is_staff() {
311 DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
312 } else {
313 DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
314 };
315 let capture_rate = language::language_settings::all_language_settings(None, cx)
316 .edit_predictions
317 .example_capture_rate
318 .unwrap_or(default_rate);
319 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
320 && rand::random::<u16>() % 10_000 < capture_rate
321}
322
323pub(crate) fn should_send_testing_zeta2_request() -> bool {
324 rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` on a file with both uncommitted on-disk changes
    /// (relative to HEAD) and in-editor edits, asserting every field of the resulting spec.
    /// An edit to a file outside the project's worktree is also made, and the expected
    /// `edit_history` shows it is excluded.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: two comment lines not yet committed.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer so its edits are recorded as events.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // In-editor edits: these should appear in `edit_history`, not `uncommitted_diff`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Cursor at the very start of the buffer; populate the expected patch template.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is time-based; pin it so the spec can be compared.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                // Checked field-by-field below.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(prompt_input.events.len() > 0, "should have captured events");
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs the globals the test relies on: a test `SettingsStore`, test logging, a
    /// `Client` backed by a fake HTTP client that always answers 404, the language-model
    /// globals, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}