1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use feature_flags::FeatureFlagAppExt as _;
9use gpui::{App, Entity, Task};
10use language::{Buffer, ToPoint as _};
11use project::{Project, WorktreeId};
12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
13use text::{BufferSnapshot as TextBufferSnapshot, Point};
14
/// Capture rate, expressed as a count out of every 10,000 predictions, used by
/// `should_sample_edit_prediction_example_capture` when the
/// `example_capture_rate` setting is not set.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
16
17pub fn capture_example(
18 project: Entity<Project>,
19 buffer: Entity<Buffer>,
20 cursor_anchor: language::Anchor,
21 mut events: Vec<StoredEvent>,
22 populate_expected_patch: bool,
23 cx: &mut App,
24) -> Option<Task<Result<ExampleSpec>>> {
25 let snapshot = buffer.read(cx).snapshot();
26 let file = snapshot.file()?;
27 let worktree_id = file.worktree_id(cx);
28 let repository = project.read(cx).active_repository(cx)?;
29 let repository_snapshot = repository.read(cx).snapshot();
30 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
31 let root_name = worktree.read(cx).root_name_str().to_owned();
32 let cursor_path: Arc<Path> = file.path().as_std_path().into();
33 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
34 return None;
35 }
36
37 let repository_url = repository_snapshot
38 .remote_origin_url
39 .clone()
40 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
41 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
42
43 let git_store = project.read(cx).git_store().clone();
44
45 Some(cx.spawn(async move |mut cx| {
46 let snapshots_by_path =
47 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
48
49 events.retain(|stored_event| {
50 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
51 let relative_path = strip_root_name(path, &root_name);
52 snapshots_by_path.contains_key(relative_path)
53 });
54
55 let line_comment_prefix = snapshot
56 .language()
57 .and_then(|lang| lang.config().line_comments.first())
58 .map(|s| s.to_string())
59 .unwrap_or_default();
60 let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
61 .background_executor()
62 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
63 .await;
64 let uncommitted_diff = cx
65 .background_executor()
66 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
67 .await;
68
69 let mut edit_history = String::new();
70 for stored_event in &events {
71 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
72 if !edit_history.ends_with('\n') {
73 edit_history.push('\n');
74 }
75 }
76
77 // Initialize an empty patch with context lines, to make it easy
78 // to write the expected patch by hand.
79 let mut expected_patches = Vec::new();
80 if populate_expected_patch {
81 let mut empty_patch = String::new();
82 let start_row = cursor_excerpt_range.start.row + 1;
83 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
84 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
85 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
86 writeln!(
87 &mut empty_patch,
88 "@@ -{},{} +{},{} @@",
89 start_row, row_count, start_row, row_count,
90 )
91 .ok();
92 for line in cursor_excerpt.lines() {
93 writeln!(&mut empty_patch, " {}", line).ok();
94 }
95 expected_patches.push(empty_patch);
96 }
97
98 let mut spec = ExampleSpec {
99 name: generate_timestamp_name(),
100 repository_url,
101 revision,
102 tags: Vec::new(),
103 reasoning: None,
104 uncommitted_diff,
105 cursor_path,
106 cursor_position: String::new(),
107 edit_history,
108 expected_patches,
109 };
110 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
111 Ok(spec)
112 }))
113}
114
/// Returns `path` with its leading `root_name` component(s) removed.
///
/// If `path` does not start with `root_name` (compared component-wise), it is
/// returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
118
119fn write_event_with_relative_paths(
120 output: &mut String,
121 event: &zeta_prompt::Event,
122 root_name: &str,
123) {
124 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
125 for component in strip_root_name(path, root_name).components() {
126 output.push('/');
127 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
128 }
129 }
130
131 let zeta_prompt::Event::BufferChange {
132 path,
133 old_path,
134 diff,
135 ..
136 } = event;
137
138 output.push_str("--- a");
139 write_relative_path(output, old_path.as_ref(), root_name);
140 output.push_str("\n+++ b");
141 write_relative_path(output, path.as_ref(), root_name);
142 output.push('\n');
143 output.push_str(diff);
144}
145
146fn compute_cursor_excerpt(
147 snapshot: &language::BufferSnapshot,
148 cursor_anchor: language::Anchor,
149) -> (String, usize, Range<Point>) {
150 use text::ToOffset as _;
151
152 let cursor_point = cursor_anchor.to_point(snapshot);
153 let (_editable_range, context_range) =
154 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
155 let context_start_offset = context_range.start.to_offset(snapshot);
156 let cursor_offset = cursor_anchor.to_offset(snapshot);
157 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
158 let excerpt = snapshot
159 .text_for_range(context_range.clone())
160 .collect::<String>();
161 (excerpt, cursor_offset_in_excerpt, context_range)
162}
163
164async fn collect_snapshots(
165 project: &Entity<Project>,
166 git_store: &Entity<project::git_store::GitStore>,
167 worktree_id: WorktreeId,
168 events: &[StoredEvent],
169 cx: &mut gpui::AsyncApp,
170) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
171 let mut snapshots_by_path = HashMap::default();
172 for stored_event in events {
173 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
174 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
175 let project_path = project
176 .find_project_path(path, cx)
177 .filter(|path| path.worktree_id == worktree_id)?;
178 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
179 Some((project_path, relative_path))
180 }) {
181 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
182 let buffer = project
183 .update(cx, |project, cx| {
184 project.open_buffer(project_path.clone(), cx)
185 })
186 .await?;
187 let diff = git_store
188 .update(cx, |git_store, cx| {
189 git_store.open_uncommitted_diff(buffer.clone(), cx)
190 })
191 .await?;
192 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
193 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
194 }
195 }
196 }
197 Ok(snapshots_by_path)
198}
199
200fn compute_uncommitted_diff(
201 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
202) -> String {
203 let mut uncommitted_diff = String::new();
204 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
205 if let Some(head_text) = &diff_snapshot.base_text_string() {
206 let file_diff = language::unified_diff(head_text, &before_text.text());
207 if !file_diff.is_empty() {
208 let path_str = relative_path.to_string_lossy();
209 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
210 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
211 uncommitted_diff.push_str(&file_diff);
212 if !uncommitted_diff.ends_with('\n') {
213 uncommitted_diff.push('\n');
214 }
215 }
216 }
217 }
218 uncommitted_diff
219}
220
221fn generate_timestamp_name() -> String {
222 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
223 match format {
224 Ok(format) => {
225 let now = time::OffsetDateTime::now_local()
226 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
227 now.format(&format)
228 .unwrap_or_else(|_| "unknown-time".to_string())
229 }
230 Err(_) => "unknown-time".to_string(),
231 }
232}
233
234pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
235 let capture_rate = language::language_settings::all_language_settings(None, cx)
236 .edit_predictions
237 .example_capture_rate
238 .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
239 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
240 && rand::random::<u16>() % 10_000 < capture_rate
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` against a fake file system and
    /// fake git repository.
    ///
    /// The scenario has three layers of content for `src/main.rs`: the
    /// committed version, the on-disk version (two extra comments), and the
    /// edited buffer (two more comments). It also edits a file outside the
    /// project's worktree to verify that such edits do not appear in the
    /// captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Registered below as the HEAD version of `src/main.rs`.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // What is on disk: the committed contents plus "comment 1" and
        // "comment 2", i.e. changes that are not committed.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the store before editing it, so the edits
        // below show up in the edit history fetched later.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        buffer.update(cx, |buffer, cx| {
            // Insert at the later row first; the second insertion is above it,
            // so neither insertion shifts the other's target row.
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer and with an
        // empty expected patch requested.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is a timestamp; replace it so the comparison
        // below is deterministic.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD versus the buffer contents before the recorded edits:
                // only comments 1 and 2 appear.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the edits to `src/main.rs` appear; the external file's
                // edit is absent.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // A single patch consisting solely of context lines.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ]
            }
        );
    }

    /// Installs the globals this test relies on: a test settings store, test
    /// logging, a client backed by a fake HTTP client that answers 404, the
    /// language-model subsystem, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}