1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use feature_flags::FeatureFlagAppExt as _;
9use gpui::{App, Entity, Task};
10use language::{Buffer, ToPoint as _};
11use project::{Project, WorktreeId};
12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
13use text::{BufferSnapshot as TextBufferSnapshot, Point};
14
/// Capture rate, expressed as a count per 10,000 predictions, that
/// `should_sample_edit_prediction_example_capture` falls back to when the
/// settings do not provide `example_capture_rate`.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Rate, expressed as a count per 10,000 predictions, at which
/// `should_send_testing_zeta2_request` returns `true`.
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
17
18pub fn capture_example(
19 project: Entity<Project>,
20 buffer: Entity<Buffer>,
21 cursor_anchor: language::Anchor,
22 mut events: Vec<StoredEvent>,
23 populate_expected_patch: bool,
24 cx: &mut App,
25) -> Option<Task<Result<ExampleSpec>>> {
26 let snapshot = buffer.read(cx).snapshot();
27 let file = snapshot.file()?;
28 let worktree_id = file.worktree_id(cx);
29 let repository = project.read(cx).active_repository(cx)?;
30 let repository_snapshot = repository.read(cx).snapshot();
31 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
32 let root_name = worktree.read(cx).root_name_str().to_owned();
33 let cursor_path: Arc<Path> = file.path().as_std_path().into();
34 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
35 return None;
36 }
37
38 let repository_url = repository_snapshot
39 .remote_origin_url
40 .clone()
41 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
42 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
43
44 let git_store = project.read(cx).git_store().clone();
45
46 Some(cx.spawn(async move |mut cx| {
47 let snapshots_by_path =
48 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
49
50 events.retain(|stored_event| {
51 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
52 let relative_path = strip_root_name(path, &root_name);
53 snapshots_by_path.contains_key(relative_path)
54 });
55
56 let line_comment_prefix = snapshot
57 .language()
58 .and_then(|lang| lang.config().line_comments.first())
59 .map(|s| s.to_string())
60 .unwrap_or_default();
61 let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
62 .background_executor()
63 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
64 .await;
65 let uncommitted_diff = cx
66 .background_executor()
67 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
68 .await;
69
70 let mut edit_history = String::new();
71 for stored_event in &events {
72 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
73 if !edit_history.ends_with('\n') {
74 edit_history.push('\n');
75 }
76 }
77
78 // Initialize an empty patch with context lines, to make it easy
79 // to write the expected patch by hand.
80 let mut expected_patches = Vec::new();
81 if populate_expected_patch {
82 let mut empty_patch = String::new();
83 let start_row = cursor_excerpt_range.start.row + 1;
84 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
85 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
86 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
87 writeln!(
88 &mut empty_patch,
89 "@@ -{},{} +{},{} @@",
90 start_row, row_count, start_row, row_count,
91 )
92 .ok();
93 for line in cursor_excerpt.lines() {
94 writeln!(&mut empty_patch, " {}", line).ok();
95 }
96 expected_patches.push(empty_patch);
97 }
98
99 let mut spec = ExampleSpec {
100 name: generate_timestamp_name(),
101 repository_url,
102 revision,
103 tags: Vec::new(),
104 reasoning: None,
105 uncommitted_diff,
106 cursor_path,
107 cursor_position: String::new(),
108 edit_history,
109 expected_patches,
110 };
111 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
112 Ok(spec)
113 }))
114}
115
/// Returns `path` with a leading `root_name` component removed.
///
/// When `path` does not start with `root_name`, it is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
119
120fn write_event_with_relative_paths(
121 output: &mut String,
122 event: &zeta_prompt::Event,
123 root_name: &str,
124) {
125 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
126 for component in strip_root_name(path, root_name).components() {
127 output.push('/');
128 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
129 }
130 }
131
132 let zeta_prompt::Event::BufferChange {
133 path,
134 old_path,
135 diff,
136 ..
137 } = event;
138
139 output.push_str("--- a");
140 write_relative_path(output, old_path.as_ref(), root_name);
141 output.push_str("\n+++ b");
142 write_relative_path(output, path.as_ref(), root_name);
143 output.push('\n');
144 output.push_str(diff);
145}
146
147fn compute_cursor_excerpt(
148 snapshot: &language::BufferSnapshot,
149 cursor_anchor: language::Anchor,
150) -> (String, usize, Range<Point>) {
151 use text::ToOffset as _;
152
153 let cursor_point = cursor_anchor.to_point(snapshot);
154 let (_editable_range, context_range) =
155 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
156 let context_start_offset = context_range.start.to_offset(snapshot);
157 let cursor_offset = cursor_anchor.to_offset(snapshot);
158 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
159 let excerpt = snapshot
160 .text_for_range(context_range.clone())
161 .collect::<String>();
162 (excerpt, cursor_offset_in_excerpt, context_range)
163}
164
165async fn collect_snapshots(
166 project: &Entity<Project>,
167 git_store: &Entity<project::git_store::GitStore>,
168 worktree_id: WorktreeId,
169 events: &[StoredEvent],
170 cx: &mut gpui::AsyncApp,
171) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
172 let mut snapshots_by_path = HashMap::default();
173 for stored_event in events {
174 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
175 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
176 let project_path = project
177 .find_project_path(path, cx)
178 .filter(|path| path.worktree_id == worktree_id)?;
179 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
180 Some((project_path, relative_path))
181 }) {
182 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
183 let buffer = project
184 .update(cx, |project, cx| {
185 project.open_buffer(project_path.clone(), cx)
186 })
187 .await?;
188 let diff = git_store
189 .update(cx, |git_store, cx| {
190 git_store.open_uncommitted_diff(buffer.clone(), cx)
191 })
192 .await?;
193 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
194 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
195 }
196 }
197 }
198 Ok(snapshots_by_path)
199}
200
201fn compute_uncommitted_diff(
202 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
203) -> String {
204 let mut uncommitted_diff = String::new();
205 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
206 if let Some(head_text) = &diff_snapshot.base_text_string() {
207 let file_diff = language::unified_diff(head_text, &before_text.text());
208 if !file_diff.is_empty() {
209 let path_str = relative_path.to_string_lossy();
210 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
211 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
212 uncommitted_diff.push_str(&file_diff);
213 if !uncommitted_diff.ends_with('\n') {
214 uncommitted_diff.push('\n');
215 }
216 }
217 }
218 }
219 uncommitted_diff
220}
221
222fn generate_timestamp_name() -> String {
223 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
224 match format {
225 Ok(format) => {
226 let now = time::OffsetDateTime::now_local()
227 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
228 now.format(&format)
229 .unwrap_or_else(|_| "unknown-time".to_string())
230 }
231 Err(_) => "unknown-time".to_string(),
232 }
233}
234
235pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
236 let capture_rate = language::language_settings::all_language_settings(None, cx)
237 .edit_predictions
238 .example_capture_rate
239 .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
240 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
241 && rand::random::<u16>() % 10_000 < capture_rate
242}
243
244pub(crate) fn should_send_testing_zeta2_request() -> bool {
245 rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// Sets up a fake git repository whose working copy differs from HEAD,
    /// records edits to a buffer inside the project and to a file outside the
    /// worktree, then compares every field of the captured `ExampleSpec`
    /// against the expected value.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Working-copy contents: HEAD plus two comment lines, i.e. the
        // uncommitted changes that exist before any edit made in this test.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // Give the fake repository a head commit and an origin remote; the
        // expected `revision` and `repository_url` below mirror these values.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the prediction store before editing; the
        // edits below are expected to show up in its edit history.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Insert two comment lines. The later row is edited first, so the row
        // index used for the second insertion is unaffected by the first.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at `Anchor::MIN` and with the skeleton
        // expected patch enabled.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is derived from the current time, so replace it
        // with a fixed value before comparing.
        example.name = "test".to_string();

        // Expectations worth noting:
        // - `uncommitted_diff` contains only the changes that were already on
        //   disk (comments 1 and 2), not the edits made during this test.
        // - `edit_history` contains only the in-project edit; the event for the
        //   external file is absent.
        // - `expected_patches` holds one context-only patch over the excerpt.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ]
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store, a
    /// client backed by a fake HTTP client created with `with_404_response`,
    /// the `language_model` initialization, and the global
    /// `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}