1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position,
4 example_spec::{
5 CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
6 ExampleSpec, MAX_CURSOR_FILE_SIZE,
7 },
8};
9use anyhow::Result;
10use buffer_diff::BufferDiffSnapshot;
11use collections::HashMap;
12use feature_flags::FeatureFlagAppExt as _;
13use gpui::{App, Entity, Task};
14use language::{Buffer, ToPoint as _};
15use project::{Project, WorktreeId};
16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
18
/// Fallback capture rate for non-staff users, expressed as the number of
/// predictions out of every 10,000 for which an example is captured. Used when
/// the `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;

/// Fallback capture rate for staff users, in the same "per 10,000 predictions"
/// unit. Used when the `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
21
22pub fn capture_example(
23 project: Entity<Project>,
24 buffer: Entity<Buffer>,
25 cursor_anchor: language::Anchor,
26 mut events: Vec<StoredEvent>,
27 related_files: Vec<zeta_prompt::RelatedFile>,
28 populate_expected_patch: bool,
29 cx: &mut App,
30) -> Option<Task<Result<ExampleSpec>>> {
31 let snapshot = buffer.read(cx).snapshot();
32 let file = snapshot.file()?;
33 let worktree_id = file.worktree_id(cx);
34 let repository = project.read(cx).active_repository(cx)?;
35 let repository_snapshot = repository.read(cx).snapshot();
36 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
37 let root_name = worktree.read(cx).root_name_str().to_owned();
38 let cursor_path: Arc<Path> = file.path().as_std_path().into();
39 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
40 return None;
41 }
42
43 let repository_url = repository_snapshot
44 .remote_origin_url
45 .clone()
46 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
47 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
48
49 let git_store = project.read(cx).git_store().clone();
50
51 Some(cx.spawn(async move |mut cx| {
52 let snapshots_by_path =
53 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
54
55 events.retain(|stored_event| {
56 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
57 let relative_path = strip_root_name(path, &root_name);
58 snapshots_by_path.contains_key(relative_path)
59 });
60
61 let line_comment_prefix = snapshot
62 .language()
63 .and_then(|lang| lang.config().line_comments.first())
64 .map(|s| s.to_string())
65 .unwrap_or_default();
66
67 let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
68 let cursor_point = cursor_anchor.to_point(&snapshot);
69 let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
70 Some(snapshot.text())
71 } else {
72 None
73 };
74
75 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
76 .background_executor()
77 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
78 .await;
79 let uncommitted_diff = cx
80 .background_executor()
81 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
82 .await;
83
84 let mut edit_history = String::new();
85 for stored_event in &events {
86 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
87 if !edit_history.ends_with('\n') {
88 edit_history.push('\n');
89 }
90 }
91
92 // Initialize an empty patch with context lines, to make it easy
93 // to write the expected patch by hand.
94 let mut expected_patches = Vec::new();
95 let mut rejected_patch = None;
96 if populate_expected_patch {
97 let mut empty_patch = String::new();
98 let start_row = cursor_excerpt_range.start.row + 1;
99 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
100 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
101 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
102 writeln!(
103 &mut empty_patch,
104 "@@ -{},{} +{},{} @@",
105 start_row, row_count, start_row, row_count,
106 )
107 .ok();
108 for line in cursor_excerpt.lines() {
109 writeln!(&mut empty_patch, " {}", line).ok();
110 }
111
112 expected_patches.push(empty_patch.clone());
113 rejected_patch = Some(empty_patch);
114 }
115
116 let prompt_input = cursor_file_content.map(|content| {
117 let captured_events: Vec<CapturedEvent> = events
118 .iter()
119 .map(|stored_event| {
120 let zeta_prompt::Event::BufferChange {
121 path,
122 old_path,
123 diff,
124 predicted,
125 in_open_source_repo,
126 } = stored_event.event.as_ref();
127 CapturedEvent {
128 path: strip_root_name(path, &root_name).into(),
129 old_path: strip_root_name(old_path, &root_name).into(),
130 diff: diff.clone(),
131 predicted: *predicted,
132 in_open_source_repo: *in_open_source_repo,
133 }
134 })
135 .collect();
136
137 let captured_related_files: Vec<CapturedRelatedFile> = related_files
138 .iter()
139 .map(|rf| CapturedRelatedFile {
140 path: strip_root_name(&rf.path, &root_name).into(),
141 max_row: rf.max_row,
142 excerpts: rf
143 .excerpts
144 .iter()
145 .map(|e| CapturedRelatedExcerpt {
146 row_range: e.row_range.clone(),
147 text: e.text.to_string(),
148 })
149 .collect(),
150 })
151 .collect();
152
153 CapturedPromptInput {
154 cursor_file_content: content,
155 cursor_offset: full_cursor_offset,
156 cursor_row: cursor_point.row,
157 cursor_column: cursor_point.column,
158 events: captured_events,
159 related_files: captured_related_files,
160 }
161 });
162
163 let mut spec = ExampleSpec {
164 name: generate_timestamp_name(),
165 repository_url,
166 revision,
167 tags: Vec::new(),
168 reasoning: None,
169 uncommitted_diff,
170 cursor_path,
171 cursor_position: String::new(),
172 edit_history,
173 expected_patches,
174 rejected_patch,
175 captured_prompt_input: prompt_input,
176 telemetry: None,
177 };
178 spec.set_cursor_excerpt(
179 &cursor_excerpt,
180 cursor_offset_in_excerpt,
181 &line_comment_prefix,
182 );
183 Ok(spec)
184 }))
185}
186
/// Returns `path` with the leading `root_name` prefix removed.
///
/// Prefix matching follows [`Path::strip_prefix`], so it works on whole path
/// components. A path that does not start with `root_name` is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
190
191fn write_event_with_relative_paths(
192 output: &mut String,
193 event: &zeta_prompt::Event,
194 root_name: &str,
195) {
196 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
197 for component in strip_root_name(path, root_name).components() {
198 output.push('/');
199 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
200 }
201 }
202
203 let zeta_prompt::Event::BufferChange {
204 path,
205 old_path,
206 diff,
207 ..
208 } = event;
209
210 output.push_str("--- a");
211 write_relative_path(output, old_path.as_ref(), root_name);
212 output.push_str("\n+++ b");
213 write_relative_path(output, path.as_ref(), root_name);
214 output.push('\n');
215 output.push_str(diff);
216}
217
218fn compute_cursor_excerpt(
219 snapshot: &language::BufferSnapshot,
220 cursor_anchor: language::Anchor,
221) -> (String, usize, Range<Point>) {
222 use text::ToOffset as _;
223
224 let cursor_point = cursor_anchor.to_point(snapshot);
225 let (_editable_range, context_range) =
226 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
227 let context_start_offset = context_range.start.to_offset(snapshot);
228 let cursor_offset = cursor_anchor.to_offset(snapshot);
229 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
230 let excerpt = snapshot
231 .text_for_range(context_range.clone())
232 .collect::<String>();
233 (excerpt, cursor_offset_in_excerpt, context_range)
234}
235
236async fn collect_snapshots(
237 project: &Entity<Project>,
238 git_store: &Entity<project::git_store::GitStore>,
239 worktree_id: WorktreeId,
240 events: &[StoredEvent],
241 cx: &mut gpui::AsyncApp,
242) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
243 let mut snapshots_by_path = HashMap::default();
244 for stored_event in events {
245 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
246 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
247 let project_path = project
248 .find_project_path(path, cx)
249 .filter(|path| path.worktree_id == worktree_id)?;
250 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
251 Some((project_path, relative_path))
252 }) {
253 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
254 let buffer = project
255 .update(cx, |project, cx| {
256 project.open_buffer(project_path.clone(), cx)
257 })
258 .await?;
259 let diff = git_store
260 .update(cx, |git_store, cx| {
261 git_store.open_uncommitted_diff(buffer.clone(), cx)
262 })
263 .await?;
264 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
265 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
266 }
267 }
268 }
269 Ok(snapshots_by_path)
270}
271
272fn compute_uncommitted_diff(
273 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
274) -> String {
275 let mut uncommitted_diff = String::new();
276 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
277 if let Some(head_text) = &diff_snapshot.base_text_string() {
278 let file_diff = language::unified_diff(head_text, &before_text.text());
279 if !file_diff.is_empty() {
280 let path_str = relative_path.to_string_lossy();
281 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
282 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
283 uncommitted_diff.push_str(&file_diff);
284 if !uncommitted_diff.ends_with('\n') {
285 uncommitted_diff.push('\n');
286 }
287 }
288 }
289 }
290 uncommitted_diff
291}
292
293fn generate_timestamp_name() -> String {
294 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
295 match format {
296 Ok(format) => {
297 let now = time::OffsetDateTime::now_local()
298 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
299 now.format(&format)
300 .unwrap_or_else(|_| "unknown-time".to_string())
301 }
302 Err(_) => "unknown-time".to_string(),
303 }
304}
305
306pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
307 let default_rate = if cx.is_staff() {
308 DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
309 } else {
310 DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
311 };
312 let capture_rate = language::language_settings::all_language_settings(None, cx)
313 .edit_predictions
314 .example_capture_rate
315 .unwrap_or(default_rate);
316 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
317 && rand::random::<u16>() % 10_000 < capture_rate
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// `strip_root_name` removes only a matching leading component and leaves
    /// every other path untouched.
    #[test]
    fn test_strip_root_name() {
        assert_eq!(
            strip_root_name(Path::new("project/src/main.rs"), "project"),
            Path::new("src/main.rs")
        );
        assert_eq!(
            strip_root_name(Path::new("/external/external.rs"), "project"),
            Path::new("/external/external.rs")
        );
        assert_eq!(
            strip_root_name(Path::new("projectx/a.rs"), "project"),
            Path::new("projectx/a.rs")
        );
    }

    /// End-to-end check of `capture_example`: the file on disk differs from the
    /// committed version (uncommitted diff), the buffer is then edited (edit
    /// history), and an edit to a file outside the worktree must be filtered out.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is time-based, so pin it before comparing.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                // Compared field-by-field below instead of being spelled out here.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(
            !prompt_input.events.is_empty(),
            "should have captured events"
        );
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs the globals the test relies on: settings store, logging, a fake
    /// HTTP client, the language-model registry and the edit prediction store.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}