1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use feature_flags::FeatureFlagAppExt as _;
9use gpui::{App, Entity, Task};
10use language::{Buffer, ToPoint as _};
11use project::{Project, WorktreeId};
12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
13use text::{BufferSnapshot as TextBufferSnapshot, Point};
14
/// Number of predictions out of every 10,000 that are sampled for example capture
/// for non-staff users when the `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Number of predictions out of every 10,000 that are sampled for example capture
/// for staff users when the `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
17
18pub fn capture_example(
19 project: Entity<Project>,
20 buffer: Entity<Buffer>,
21 cursor_anchor: language::Anchor,
22 mut events: Vec<StoredEvent>,
23 populate_expected_patch: bool,
24 cx: &mut App,
25) -> Option<Task<Result<ExampleSpec>>> {
26 let snapshot = buffer.read(cx).snapshot();
27 let file = snapshot.file()?;
28 let worktree_id = file.worktree_id(cx);
29 let repository = project.read(cx).active_repository(cx)?;
30 let repository_snapshot = repository.read(cx).snapshot();
31 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
32 let root_name = worktree.read(cx).root_name_str().to_owned();
33 let cursor_path: Arc<Path> = file.path().as_std_path().into();
34 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
35 return None;
36 }
37
38 let repository_url = repository_snapshot
39 .remote_origin_url
40 .clone()
41 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
42 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
43
44 let git_store = project.read(cx).git_store().clone();
45
46 Some(cx.spawn(async move |mut cx| {
47 let snapshots_by_path =
48 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
49
50 events.retain(|stored_event| {
51 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
52 let relative_path = strip_root_name(path, &root_name);
53 snapshots_by_path.contains_key(relative_path)
54 });
55
56 let line_comment_prefix = snapshot
57 .language()
58 .and_then(|lang| lang.config().line_comments.first())
59 .map(|s| s.to_string())
60 .unwrap_or_default();
61 let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
62 .background_executor()
63 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
64 .await;
65 let uncommitted_diff = cx
66 .background_executor()
67 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
68 .await;
69
70 let mut edit_history = String::new();
71 for stored_event in &events {
72 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
73 if !edit_history.ends_with('\n') {
74 edit_history.push('\n');
75 }
76 }
77
78 // Initialize an empty patch with context lines, to make it easy
79 // to write the expected patch by hand.
80 let mut expected_patches = Vec::new();
81 let mut rejected_patch = None;
82 if populate_expected_patch {
83 let mut empty_patch = String::new();
84 let start_row = cursor_excerpt_range.start.row + 1;
85 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
86 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
87 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
88 writeln!(
89 &mut empty_patch,
90 "@@ -{},{} +{},{} @@",
91 start_row, row_count, start_row, row_count,
92 )
93 .ok();
94 for line in cursor_excerpt.lines() {
95 writeln!(&mut empty_patch, " {}", line).ok();
96 }
97
98 expected_patches.push(empty_patch.clone());
99 rejected_patch = Some(empty_patch);
100 }
101
102 let mut spec = ExampleSpec {
103 name: generate_timestamp_name(),
104 repository_url,
105 revision,
106 tags: Vec::new(),
107 reasoning: None,
108 uncommitted_diff,
109 cursor_path,
110 cursor_position: String::new(),
111 edit_history,
112 expected_patches,
113 rejected_patch,
114 };
115 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
116 Ok(spec)
117 }))
118}
119
/// Returns `path` with the leading `root_name` component(s) removed, or `path` unchanged
/// when it does not start with `root_name`.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
123
124fn write_event_with_relative_paths(
125 output: &mut String,
126 event: &zeta_prompt::Event,
127 root_name: &str,
128) {
129 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
130 for component in strip_root_name(path, root_name).components() {
131 output.push('/');
132 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
133 }
134 }
135
136 let zeta_prompt::Event::BufferChange {
137 path,
138 old_path,
139 diff,
140 ..
141 } = event;
142
143 output.push_str("--- a");
144 write_relative_path(output, old_path.as_ref(), root_name);
145 output.push_str("\n+++ b");
146 write_relative_path(output, path.as_ref(), root_name);
147 output.push('\n');
148 output.push_str(diff);
149}
150
151fn compute_cursor_excerpt(
152 snapshot: &language::BufferSnapshot,
153 cursor_anchor: language::Anchor,
154) -> (String, usize, Range<Point>) {
155 use text::ToOffset as _;
156
157 let cursor_point = cursor_anchor.to_point(snapshot);
158 let (_editable_range, context_range) =
159 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
160 let context_start_offset = context_range.start.to_offset(snapshot);
161 let cursor_offset = cursor_anchor.to_offset(snapshot);
162 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
163 let excerpt = snapshot
164 .text_for_range(context_range.clone())
165 .collect::<String>();
166 (excerpt, cursor_offset_in_excerpt, context_range)
167}
168
169async fn collect_snapshots(
170 project: &Entity<Project>,
171 git_store: &Entity<project::git_store::GitStore>,
172 worktree_id: WorktreeId,
173 events: &[StoredEvent],
174 cx: &mut gpui::AsyncApp,
175) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
176 let mut snapshots_by_path = HashMap::default();
177 for stored_event in events {
178 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
179 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
180 let project_path = project
181 .find_project_path(path, cx)
182 .filter(|path| path.worktree_id == worktree_id)?;
183 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
184 Some((project_path, relative_path))
185 }) {
186 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
187 let buffer = project
188 .update(cx, |project, cx| {
189 project.open_buffer(project_path.clone(), cx)
190 })
191 .await?;
192 let diff = git_store
193 .update(cx, |git_store, cx| {
194 git_store.open_uncommitted_diff(buffer.clone(), cx)
195 })
196 .await?;
197 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
198 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
199 }
200 }
201 }
202 Ok(snapshots_by_path)
203}
204
205fn compute_uncommitted_diff(
206 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
207) -> String {
208 let mut uncommitted_diff = String::new();
209 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
210 if let Some(head_text) = &diff_snapshot.base_text_string() {
211 let file_diff = language::unified_diff(head_text, &before_text.text());
212 if !file_diff.is_empty() {
213 let path_str = relative_path.to_string_lossy();
214 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
215 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
216 uncommitted_diff.push_str(&file_diff);
217 if !uncommitted_diff.ends_with('\n') {
218 uncommitted_diff.push('\n');
219 }
220 }
221 }
222 }
223 uncommitted_diff
224}
225
226fn generate_timestamp_name() -> String {
227 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
228 match format {
229 Ok(format) => {
230 let now = time::OffsetDateTime::now_local()
231 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
232 now.format(&format)
233 .unwrap_or_else(|_| "unknown-time".to_string())
234 }
235 Err(_) => "unknown-time".to_string(),
236 }
237}
238
239pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
240 let default_rate = if cx.is_staff() {
241 DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
242 } else {
243 DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
244 };
245 let capture_rate = language::language_settings::all_language_settings(None, cx)
246 .edit_predictions
247 .example_capture_rate
248 .unwrap_or(default_rate);
249 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
250 && rand::random::<u16>() % 10_000 < capture_rate
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` against a fake repository.
    ///
    /// The repository HEAD, the on-disk file, and the in-memory buffer all differ, and an
    /// unrelated file outside the project is edited as well. The captured spec is expected
    /// to contain the HEAD-to-disk changes as `uncommitted_diff`, only the in-project buffer
    /// edits as `edit_history`, and a context-only placeholder patch for both
    /// `expected_patches` and `rejected_patch`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at the repository HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer before editing it so the edits below show up in the
        // store's edit history.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Insert two more comment lines into the in-memory buffer only.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer and with the
        // placeholder patch enabled.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is time-based; overwrite it so the comparison is stable.
        example.name = "test".to_string();

        // Note that the expected `edit_history` has no section for the external file,
        // even though its edit was present in `events`.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                )
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store, a client backed by
    /// a fake HTTP client that answers 404, the language-model globals, and the global
    /// `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}