1use crate::{StoredEvent, example_spec::ExampleSpec};
2use anyhow::Result;
3use buffer_diff::BufferDiffSnapshot;
4#[cfg(test)]
5use client::RefreshLlmTokenListener;
6use collections::HashMap;
7use gpui::{App, Entity, Task};
8use language::Buffer;
9use project::{Project, WorktreeId};
10use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
11use text::{BufferSnapshot as TextBufferSnapshot, Point};
12
13pub fn capture_example(
14 project: Entity<Project>,
15 buffer: Entity<Buffer>,
16 cursor_anchor: language::Anchor,
17 mut events: Vec<StoredEvent>,
18 populate_expected_patch: bool,
19 cx: &mut App,
20) -> Option<Task<Result<ExampleSpec>>> {
21 let snapshot = buffer.read(cx).snapshot();
22 let file = snapshot.file()?;
23 let worktree_id = file.worktree_id(cx);
24 let repository = project.read(cx).active_repository(cx)?;
25 let repository_snapshot = repository.read(cx).snapshot();
26 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
27 let root_name = worktree.read(cx).root_name_str().to_owned();
28 let cursor_path: Arc<Path> = file.path().as_std_path().into();
29 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
30 return None;
31 }
32
33 let repository_url = repository_snapshot
34 .remote_origin_url
35 .clone()
36 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
37 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
38
39 let git_store = project.read(cx).git_store().clone();
40
41 Some(cx.spawn(async move |mut cx| {
42 let snapshots_by_path =
43 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
44
45 events.retain(|stored_event| {
46 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
47 let relative_path = strip_root_name(path, &root_name);
48 snapshots_by_path.contains_key(relative_path)
49 });
50
51 let line_comment_prefix = snapshot
52 .language()
53 .and_then(|lang| lang.config().line_comments.first())
54 .map(|s| s.to_string())
55 .unwrap_or_default();
56
57 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
58 .background_executor()
59 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
60 .await;
61 let uncommitted_diff = cx
62 .background_executor()
63 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
64 .await;
65
66 let mut edit_history = String::new();
67 for stored_event in &events {
68 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
69 if !edit_history.ends_with('\n') {
70 edit_history.push('\n');
71 }
72 }
73
74 // Initialize an empty patch with context lines, to make it easy
75 // to write the expected patch by hand.
76 let mut expected_patches = Vec::new();
77 let mut rejected_patch = None;
78 if populate_expected_patch {
79 let mut empty_patch = String::new();
80 let start_row = cursor_excerpt_range.start.row + 1;
81 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
82 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
83 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
84 writeln!(
85 &mut empty_patch,
86 "@@ -{},{} +{},{} @@",
87 start_row, row_count, start_row, row_count,
88 )
89 .ok();
90 for line in cursor_excerpt.lines() {
91 writeln!(&mut empty_patch, " {}", line).ok();
92 }
93
94 expected_patches.push(empty_patch.clone());
95 rejected_patch = Some(empty_patch);
96 }
97
98 let mut spec = ExampleSpec {
99 name: generate_timestamp_name(),
100 repository_url,
101 revision,
102 tags: Vec::new(),
103 reasoning: None,
104 uncommitted_diff,
105 cursor_path,
106 cursor_position: String::new(),
107 edit_history,
108 expected_patches,
109 rejected_patch,
110 telemetry: None,
111 human_feedback: Vec::new(),
112 rating: None,
113 };
114 spec.set_cursor_excerpt(
115 &cursor_excerpt,
116 cursor_offset_in_excerpt,
117 &line_comment_prefix,
118 );
119 Ok(spec)
120 }))
121}
122
/// Returns `path` with the leading worktree root component `root_name` removed.
///
/// `Path::strip_prefix` matches whole components, so a path whose first
/// component only starts with `root_name`, or any path not rooted at it, is
/// returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
126
127fn write_event_with_relative_paths(
128 output: &mut String,
129 event: &zeta_prompt::Event,
130 root_name: &str,
131) {
132 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
133 for component in strip_root_name(path, root_name).components() {
134 output.push('/');
135 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
136 }
137 }
138
139 let zeta_prompt::Event::BufferChange {
140 path,
141 old_path,
142 diff,
143 ..
144 } = event;
145
146 output.push_str("--- a");
147 write_relative_path(output, old_path.as_ref(), root_name);
148 output.push_str("\n+++ b");
149 write_relative_path(output, path.as_ref(), root_name);
150 output.push('\n');
151 output.push_str(diff);
152}
153
154fn compute_cursor_excerpt(
155 snapshot: &language::BufferSnapshot,
156 cursor_anchor: language::Anchor,
157) -> (String, usize, Range<Point>) {
158 use text::ToOffset as _;
159 use text::ToPoint as _;
160
161 let cursor_offset = cursor_anchor.to_offset(snapshot);
162 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
163 crate::cursor_excerpt::compute_cursor_excerpt(snapshot, cursor_offset);
164 let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
165 snapshot,
166 cursor_offset,
167 &excerpt_offset_range,
168 );
169 let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
170 let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
171 &excerpt_text,
172 cursor_offset_in_excerpt,
173 &syntax_ranges,
174 100,
175 50,
176 );
177 let context_text = excerpt_text[context_range.clone()].to_string();
178 let cursor_in_context = cursor_offset_in_excerpt.saturating_sub(context_range.start);
179 let context_buffer_start =
180 (excerpt_offset_range.start + context_range.start).to_point(snapshot);
181 let context_buffer_end = (excerpt_offset_range.start + context_range.end).to_point(snapshot);
182 (
183 context_text,
184 cursor_in_context,
185 context_buffer_start..context_buffer_end,
186 )
187}
188
189async fn collect_snapshots(
190 project: &Entity<Project>,
191 git_store: &Entity<project::git_store::GitStore>,
192 worktree_id: WorktreeId,
193 events: &[StoredEvent],
194 cx: &mut gpui::AsyncApp,
195) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
196 let mut snapshots_by_path = HashMap::default();
197 for stored_event in events {
198 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
199 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
200 let project_path = project
201 .find_project_path(path, cx)
202 .filter(|path| path.worktree_id == worktree_id)?;
203 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
204 Some((project_path, relative_path))
205 }) {
206 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
207 let buffer = project
208 .update(cx, |project, cx| {
209 project.open_buffer(project_path.clone(), cx)
210 })
211 .await?;
212 let diff = git_store
213 .update(cx, |git_store, cx| {
214 git_store.open_uncommitted_diff(buffer.clone(), cx)
215 })
216 .await?;
217 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
218 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
219 }
220 }
221 }
222 Ok(snapshots_by_path)
223}
224
225fn compute_uncommitted_diff(
226 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
227) -> String {
228 let mut uncommitted_diff = String::new();
229 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
230 if let Some(head_text) = &diff_snapshot.base_text_string() {
231 let file_diff = language::unified_diff(head_text, &before_text.text());
232 if !file_diff.is_empty() {
233 let path_str = relative_path.to_string_lossy();
234 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
235 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
236 uncommitted_diff.push_str(&file_diff);
237 if !uncommitted_diff.ends_with('\n') {
238 uncommitted_diff.push('\n');
239 }
240 }
241 }
242 }
243 uncommitted_diff
244}
245
246fn generate_timestamp_name() -> String {
247 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
248 match format {
249 Ok(format) => {
250 let now = time::OffsetDateTime::now_local()
251 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
252 now.format(&format)
253 .unwrap_or_else(|_| "unknown-time".to_string())
254 }
255 Err(_) => "unknown-time".to_string(),
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`, comparing every field of the
    /// resulting `ExampleSpec`.
    ///
    /// Setup:
    /// - a repository whose on-disk `src/main.rs` differs from HEAD (these
    ///   changes are expected in `uncommitted_diff`),
    /// - two edits made after the buffer is registered (expected in
    ///   `edit_history`), and
    /// - one edit to a file outside the worktree. It is present in `events`
    ///   but absent from the expected `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // On-disk contents: HEAD plus "comment 1" and "comment 2".
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // HEAD revision and remote URL; the expected spec's `revision` and
        // `repository_url` come from here.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Edits made after registration. The expected `edit_history` below
        // contains exactly these two insertions.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Cursor at the very start of the buffer; ask for the seeded patches.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is a timestamp; overwrite it so the comparison is stable.
        example.name = "test".to_string();

        // NOTE(review): the `@@ -1,16 +1,16 @@` headers below claim 16 lines, but
        // each hunk lists 15 context lines. This pins `capture_example`'s current
        // `end.row - start.row + 1` row count, presumably because the excerpt ends
        // at column 0 after the trailing newline.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );
    }

    /// Installs the globals this test relies on:
    /// - a test `SettingsStore`,
    /// - test logging,
    /// - a `Client`/`UserStore` pair backed by `FakeHttpClient::with_404_response()`,
    /// - `language_model` initialization and the `RefreshLlmTokenListener`, and
    /// - the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(cx);
            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}