1use crate::{StoredEvent, example_spec::ExampleSpec};
2use anyhow::Result;
3use buffer_diff::BufferDiffSnapshot;
4use collections::HashMap;
5use gpui::{App, Entity, Task};
6use language::Buffer;
7use project::{Project, WorktreeId};
8use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
9use text::{BufferSnapshot as TextBufferSnapshot, Point};
10
11pub fn capture_example(
12 project: Entity<Project>,
13 buffer: Entity<Buffer>,
14 cursor_anchor: language::Anchor,
15 mut events: Vec<StoredEvent>,
16 populate_expected_patch: bool,
17 cx: &mut App,
18) -> Option<Task<Result<ExampleSpec>>> {
19 let snapshot = buffer.read(cx).snapshot();
20 let file = snapshot.file()?;
21 let worktree_id = file.worktree_id(cx);
22 let repository = project.read(cx).active_repository(cx)?;
23 let repository_snapshot = repository.read(cx).snapshot();
24 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
25 let root_name = worktree.read(cx).root_name_str().to_owned();
26 let cursor_path: Arc<Path> = file.path().as_std_path().into();
27 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
28 return None;
29 }
30
31 let repository_url = repository_snapshot
32 .remote_origin_url
33 .clone()
34 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
35 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
36
37 let git_store = project.read(cx).git_store().clone();
38
39 Some(cx.spawn(async move |mut cx| {
40 let snapshots_by_path =
41 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
42
43 events.retain(|stored_event| {
44 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
45 let relative_path = strip_root_name(path, &root_name);
46 snapshots_by_path.contains_key(relative_path)
47 });
48
49 let line_comment_prefix = snapshot
50 .language()
51 .and_then(|lang| lang.config().line_comments.first())
52 .map(|s| s.to_string())
53 .unwrap_or_default();
54
55 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
56 .background_executor()
57 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
58 .await;
59 let uncommitted_diff = cx
60 .background_executor()
61 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
62 .await;
63
64 let mut edit_history = String::new();
65 for stored_event in &events {
66 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
67 if !edit_history.ends_with('\n') {
68 edit_history.push('\n');
69 }
70 }
71
72 // Initialize an empty patch with context lines, to make it easy
73 // to write the expected patch by hand.
74 let mut expected_patches = Vec::new();
75 let mut rejected_patch = None;
76 if populate_expected_patch {
77 let mut empty_patch = String::new();
78 let start_row = cursor_excerpt_range.start.row + 1;
79 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
80 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
81 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
82 writeln!(
83 &mut empty_patch,
84 "@@ -{},{} +{},{} @@",
85 start_row, row_count, start_row, row_count,
86 )
87 .ok();
88 for line in cursor_excerpt.lines() {
89 writeln!(&mut empty_patch, " {}", line).ok();
90 }
91
92 expected_patches.push(empty_patch.clone());
93 rejected_patch = Some(empty_patch);
94 }
95
96 let mut spec = ExampleSpec {
97 name: generate_timestamp_name(),
98 repository_url,
99 revision,
100 tags: Vec::new(),
101 reasoning: None,
102 uncommitted_diff,
103 cursor_path,
104 cursor_position: String::new(),
105 edit_history,
106 expected_patches,
107 rejected_patch,
108 telemetry: None,
109 human_feedback: Vec::new(),
110 rating: None,
111 };
112 spec.set_cursor_excerpt(
113 &cursor_excerpt,
114 cursor_offset_in_excerpt,
115 &line_comment_prefix,
116 );
117 Ok(spec)
118 }))
119}
120
/// Returns `path` with a leading `root_name` prefix removed.
///
/// When `path` does not start with `root_name`, it is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
124
125fn write_event_with_relative_paths(
126 output: &mut String,
127 event: &zeta_prompt::Event,
128 root_name: &str,
129) {
130 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
131 for component in strip_root_name(path, root_name).components() {
132 output.push('/');
133 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
134 }
135 }
136
137 let zeta_prompt::Event::BufferChange {
138 path,
139 old_path,
140 diff,
141 ..
142 } = event;
143
144 output.push_str("--- a");
145 write_relative_path(output, old_path.as_ref(), root_name);
146 output.push_str("\n+++ b");
147 write_relative_path(output, path.as_ref(), root_name);
148 output.push('\n');
149 output.push_str(diff);
150}
151
152fn compute_cursor_excerpt(
153 snapshot: &language::BufferSnapshot,
154 cursor_anchor: language::Anchor,
155) -> (String, usize, Range<Point>) {
156 use text::ToOffset as _;
157 use text::ToPoint as _;
158
159 let cursor_offset = cursor_anchor.to_offset(snapshot);
160 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
161 crate::cursor_excerpt::compute_cursor_excerpt(snapshot, cursor_offset);
162 let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
163 snapshot,
164 cursor_offset,
165 &excerpt_offset_range,
166 );
167 let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
168 let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
169 &excerpt_text,
170 cursor_offset_in_excerpt,
171 &syntax_ranges,
172 100,
173 50,
174 );
175 let context_text = excerpt_text[context_range.clone()].to_string();
176 let cursor_in_context = cursor_offset_in_excerpt.saturating_sub(context_range.start);
177 let context_buffer_start =
178 (excerpt_offset_range.start + context_range.start).to_point(snapshot);
179 let context_buffer_end = (excerpt_offset_range.start + context_range.end).to_point(snapshot);
180 (
181 context_text,
182 cursor_in_context,
183 context_buffer_start..context_buffer_end,
184 )
185}
186
187async fn collect_snapshots(
188 project: &Entity<Project>,
189 git_store: &Entity<project::git_store::GitStore>,
190 worktree_id: WorktreeId,
191 events: &[StoredEvent],
192 cx: &mut gpui::AsyncApp,
193) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
194 let mut snapshots_by_path = HashMap::default();
195 for stored_event in events {
196 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
197 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
198 let project_path = project
199 .find_project_path(path, cx)
200 .filter(|path| path.worktree_id == worktree_id)?;
201 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
202 Some((project_path, relative_path))
203 }) {
204 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
205 let buffer = project
206 .update(cx, |project, cx| {
207 project.open_buffer(project_path.clone(), cx)
208 })
209 .await?;
210 let diff = git_store
211 .update(cx, |git_store, cx| {
212 git_store.open_uncommitted_diff(buffer.clone(), cx)
213 })
214 .await?;
215 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
216 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
217 }
218 }
219 }
220 Ok(snapshots_by_path)
221}
222
223fn compute_uncommitted_diff(
224 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
225) -> String {
226 let mut uncommitted_diff = String::new();
227 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
228 if let Some(head_text) = &diff_snapshot.base_text_string() {
229 let file_diff = language::unified_diff(head_text, &before_text.text());
230 if !file_diff.is_empty() {
231 let path_str = relative_path.to_string_lossy();
232 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
233 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
234 uncommitted_diff.push_str(&file_diff);
235 if !uncommitted_diff.ends_with('\n') {
236 uncommitted_diff.push('\n');
237 }
238 }
239 }
240 }
241 uncommitted_diff
242}
243
244fn generate_timestamp_name() -> String {
245 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
246 match format {
247 Ok(format) => {
248 let now = time::OffsetDateTime::now_local()
249 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
250 now.format(&format)
251 .unwrap_or_else(|_| "unknown-time".to_string())
252 }
253 Err(_) => "unknown-time".to_string(),
254 }
255}
256
/// Tests for `capture_example`, run against a fake file system and a fake git
/// repository.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::RefreshLlmTokenListener;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// The fixture has three layers of content for `src/main.rs`:
    /// - the committed (HEAD) contents,
    /// - the on-disk contents, which add "comment 1" and "comment 2",
    /// - in-memory edits made during the test, which add "comment 3" and "comment 4".
    ///
    /// A file outside the project worktree is also edited, and the expected
    /// `ExampleSpec` contains no trace of it.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at HEAD.
        let committed_contents = indoc! {"
            fn main() {
            one();
            two();
            three();
            four();
            five();
            six();
            seven();
            eight();
            nine();
            }
        "};

        // Contents of `src/main.rs` on disk: HEAD plus two uncommitted comments.
        let disk_contents = indoc! {"
            fn main() {
            // comment 1
            one();
            two();
            three();
            four();
            five();
            six();
            seven();
            eight();
            // comment 2
            nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // Configure the fake repository's HEAD contents, commit sha and origin remote.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer before editing it; the edit history is read back
        // from the store further below.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Apply two in-memory edits on top of the on-disk contents. The later
        // row is edited first, so the second insertion point is unaffected.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, " // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, " // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                    // comment 1
                    one();
                    two();
                    // comment 4
                    three();
                    four();
                    // comment 3
                    five();
                    six();
                    seven();
                    eight();
                    // comment 2
                    nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer and with the
        // placeholder expected/rejected patches enabled.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::min_for_buffer(buffer.read(cx).remote_id()),
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is timestamp-based, so pin it before comparing.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD -> on-disk contents: only "comment 1" and "comment 2".
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                    fn main() {
                    + // comment 1
                    one();
                    two();
                    three();
                    @@ -7,5 +8,6 @@
                    six();
                    seven();
                    eight();
                    + // comment 2
                    nine();
                    }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                    // comment 1
                    one();
                    two();
                    // comment 4
                    three();
                    four();
                    // comment 3
                    five();
                    six();
                    seven();
                    eight();
                    // comment 2
                    nine();
                    }
                "}
                .to_string(),
                // Only the in-project edits ("comment 3" / "comment 4") appear;
                // the external file's edit is absent.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                    // comment 1
                    one();
                    two();
                    + // comment 4
                    three();
                    four();
                    + // comment 3
                    five();
                    six();
                    seven();
                "}
                .to_string(),
                // NOTE(review): the hunk header below says 16 rows while 15
                // context lines follow it — confirm this is intended.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                        fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                        }
                    "}
                    .to_string()
                ],
                // Same placeholder patch as in `expected_patches`.
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                        fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                        }
                    "}
                    .to_string()
                ),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store, test
    /// logging, a client backed by an HTTP client that answers 404, a user
    /// store, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(cx);
            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}