1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use feature_flags::FeatureFlagAppExt as _;
9use gpui::{App, Entity, Task};
10use language::{Buffer, ToPoint as _};
11use project::{Project, WorktreeId};
12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
13use text::{BufferSnapshot as TextBufferSnapshot, Point};
14
/// Number of edit predictions out of every 10,000 that are sampled for example
/// capture when the `example_capture_rate` setting is unset and the user is not staff.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Like `DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS`, but used when `cx.is_staff()`.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
/// Number of predictions out of every 10,000 for which a Zeta2 testing request is sent.
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
18
/// Captures the editing situation around `cursor_anchor` in `buffer` as an
/// [`ExampleSpec`].
///
/// Returns `None` synchronously (no task is spawned) when any prerequisite is
/// missing:
/// - the buffer has no file;
/// - the project has no active repository;
/// - the buffer's worktree cannot be found;
/// - the worktree root is not the repository's working directory;
/// - the repository has neither a `remote_origin_url` nor a `remote_upstream_url`;
/// - there is no HEAD commit.
///
/// Otherwise returns a task that:
/// - collects a snapshot per file touched by `events` within this worktree, and
///   drops events for any other file;
/// - renders those files' uncommitted diffs as `uncommitted_diff`;
/// - renders the remaining `events` as `edit_history`, stripping the worktree
///   root name from paths;
/// - records the cursor excerpt via [`ExampleSpec::set_cursor_excerpt`];
/// - when `populate_expected_patch` is true, seeds `expected_patches` and
///   `rejected_patch` with a context-only patch over the excerpt.
///
/// The task resolves to an error if opening an event's buffer or its uncommitted
/// diff fails.
pub fn capture_example(
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    cursor_anchor: language::Anchor,
    mut events: Vec<StoredEvent>,
    populate_expected_patch: bool,
    cx: &mut App,
) -> Option<Task<Result<ExampleSpec>>> {
    let snapshot = buffer.read(cx).snapshot();
    let file = snapshot.file()?;
    let worktree_id = file.worktree_id(cx);
    let repository = project.read(cx).active_repository(cx)?;
    let repository_snapshot = repository.read(cx).snapshot();
    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
    let root_name = worktree.read(cx).root_name_str().to_owned();
    let cursor_path: Arc<Path> = file.path().as_std_path().into();
    // Only capture when the worktree is the repository root, so worktree-relative
    // paths line up with repository-relative paths.
    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
        return None;
    }

    // Prefer the `origin` remote; fall back to the upstream remote.
    let repository_url = repository_snapshot
        .remote_origin_url
        .clone()
        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();

    let git_store = project.read(cx).git_store().clone();

    Some(cx.spawn(async move |mut cx| {
        let snapshots_by_path =
            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;

        // `snapshots_by_path` is keyed by worktree-relative path. Keep only events
        // whose path (root name stripped) was collected; this drops edits to files
        // outside this worktree.
        events.retain(|stored_event| {
            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
            let relative_path = strip_root_name(path, &root_name);
            snapshots_by_path.contains_key(relative_path)
        });

        // First line-comment token of the buffer's language, or "" if none.
        // Must be read before `snapshot` is moved into the background task below.
        let line_comment_prefix = snapshot
            .language()
            .and_then(|lang| lang.config().line_comments.first())
            .map(|s| s.to_string())
            .unwrap_or_default();
        let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
            .background_executor()
            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
            .await;
        let uncommitted_diff = cx
            .background_executor()
            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
            .await;

        // Concatenate the surviving events, ensuring each ends with a newline.
        let mut edit_history = String::new();
        for stored_event in &events {
            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
            if !edit_history.ends_with('\n') {
                edit_history.push('\n');
            }
        }

        // Seed the expected/rejected patches with a context-only patch (every
        // excerpt line prefixed by a space), so the expected patch is easy to
        // write by hand.
        let mut expected_patches = Vec::new();
        let mut rejected_patch = None;
        if populate_expected_patch {
            let mut empty_patch = String::new();
            // Hunk headers use 1-based line numbers.
            let start_row = cursor_excerpt_range.start.row + 1;
            // NOTE(review): when `cursor_excerpt_range` ends at column 0 (the excerpt
            // ends with a newline), `cursor_excerpt.lines()` yields one line fewer
            // than this count, so the header overstates the hunk length. The test in
            // this file pins `@@ -1,16 +1,16 @@` for a 15-line body; fixing this
            // (e.g. counting `cursor_excerpt.lines()`) needs that test updated too.
            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
            writeln!(
                &mut empty_patch,
                "@@ -{},{} +{},{} @@",
                start_row, row_count, start_row, row_count,
            )
            .ok();
            for line in cursor_excerpt.lines() {
                writeln!(&mut empty_patch, " {}", line).ok();
            }

            expected_patches.push(empty_patch.clone());
            rejected_patch = Some(empty_patch);
        }

        let mut spec = ExampleSpec {
            name: generate_timestamp_name(),
            repository_url,
            revision,
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff,
            cursor_path,
            // Filled in by `set_cursor_excerpt` below.
            cursor_position: String::new(),
            edit_history,
            expected_patches,
            rejected_patch,
        };
        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
        Ok(spec)
    }))
}
120
/// Removes a leading `root_name` component from `path`.
///
/// Matching is component-wise (so `projectx/a.rs` is not stripped by `project`).
/// Paths that do not start with `root_name` are returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
124
125fn write_event_with_relative_paths(
126 output: &mut String,
127 event: &zeta_prompt::Event,
128 root_name: &str,
129) {
130 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
131 for component in strip_root_name(path, root_name).components() {
132 output.push('/');
133 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
134 }
135 }
136
137 let zeta_prompt::Event::BufferChange {
138 path,
139 old_path,
140 diff,
141 ..
142 } = event;
143
144 output.push_str("--- a");
145 write_relative_path(output, old_path.as_ref(), root_name);
146 output.push_str("\n+++ b");
147 write_relative_path(output, path.as_ref(), root_name);
148 output.push('\n');
149 output.push_str(diff);
150}
151
152fn compute_cursor_excerpt(
153 snapshot: &language::BufferSnapshot,
154 cursor_anchor: language::Anchor,
155) -> (String, usize, Range<Point>) {
156 use text::ToOffset as _;
157
158 let cursor_point = cursor_anchor.to_point(snapshot);
159 let (_editable_range, context_range) =
160 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
161 let context_start_offset = context_range.start.to_offset(snapshot);
162 let cursor_offset = cursor_anchor.to_offset(snapshot);
163 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
164 let excerpt = snapshot
165 .text_for_range(context_range.clone())
166 .collect::<String>();
167 (excerpt, cursor_offset_in_excerpt, context_range)
168}
169
170async fn collect_snapshots(
171 project: &Entity<Project>,
172 git_store: &Entity<project::git_store::GitStore>,
173 worktree_id: WorktreeId,
174 events: &[StoredEvent],
175 cx: &mut gpui::AsyncApp,
176) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
177 let mut snapshots_by_path = HashMap::default();
178 for stored_event in events {
179 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
180 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
181 let project_path = project
182 .find_project_path(path, cx)
183 .filter(|path| path.worktree_id == worktree_id)?;
184 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
185 Some((project_path, relative_path))
186 }) {
187 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
188 let buffer = project
189 .update(cx, |project, cx| {
190 project.open_buffer(project_path.clone(), cx)
191 })
192 .await?;
193 let diff = git_store
194 .update(cx, |git_store, cx| {
195 git_store.open_uncommitted_diff(buffer.clone(), cx)
196 })
197 .await?;
198 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
199 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
200 }
201 }
202 }
203 Ok(snapshots_by_path)
204}
205
206fn compute_uncommitted_diff(
207 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
208) -> String {
209 let mut uncommitted_diff = String::new();
210 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
211 if let Some(head_text) = &diff_snapshot.base_text_string() {
212 let file_diff = language::unified_diff(head_text, &before_text.text());
213 if !file_diff.is_empty() {
214 let path_str = relative_path.to_string_lossy();
215 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
216 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
217 uncommitted_diff.push_str(&file_diff);
218 if !uncommitted_diff.ends_with('\n') {
219 uncommitted_diff.push('\n');
220 }
221 }
222 }
223 }
224 uncommitted_diff
225}
226
227fn generate_timestamp_name() -> String {
228 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
229 match format {
230 Ok(format) => {
231 let now = time::OffsetDateTime::now_local()
232 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
233 now.format(&format)
234 .unwrap_or_else(|_| "unknown-time".to_string())
235 }
236 Err(_) => "unknown-time".to_string(),
237 }
238}
239
240pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
241 let default_rate = if cx.is_staff() {
242 DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
243 } else {
244 DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
245 };
246 let capture_rate = language::language_settings::all_language_settings(None, cx)
247 .edit_predictions
248 .example_capture_rate
249 .unwrap_or(default_rate);
250 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
251 && rand::random::<u16>() % 10_000 < capture_rate
252}
253
254pub(crate) fn should_send_testing_zeta2_request() -> bool {
255 rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` on a fake git repository. It verifies:
    /// - `uncommitted_diff` reflects HEAD vs. on-disk contents (comments 1 and 2);
    /// - `edit_history` reflects the buffer edits (comments 3 and 4);
    /// - an edit to a file outside the worktree is recorded as an event but is
    ///   excluded from the captured example;
    /// - the cursor excerpt and the seeded expected/rejected patches are correct.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents recorded at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // On-disk contents: HEAD plus two uncommitted comments.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer so its edits are recorded as events by the store.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Make two in-buffer edits. These should appear in `edit_history`, not in
        // `uncommitted_diff`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is timestamp-based; overwrite it so the assertion
        // below is deterministic.
        example.name = "test".to_string();

        // NOTE(review): the seeded patches declare 16 lines (`@@ -1,16 +1,16 @@`) but
        // contain 15 body lines, because the excerpt range ends at column 0 of the
        // row after `}`. This pins the current `capture_example` behavior.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // The external-file event is absent: it was filtered out.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                )
            }
        );
    }

    /// Installs test settings, logging, a client backed by a fake HTTP client
    /// (404 for every request), language-model support, and the global
    /// `EditPredictionStore` used by the test above.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}