1use crate::{
2 StoredEvent, cursor_excerpt::editable_and_context_ranges_for_cursor_position,
3 example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use gpui::{App, Entity, Task};
9use language::{Buffer, ToPoint as _};
10use project::{Project, WorktreeId};
11use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
12use text::{BufferSnapshot as TextBufferSnapshot, Point};
13
14pub fn capture_example(
15 project: Entity<Project>,
16 buffer: Entity<Buffer>,
17 cursor_anchor: language::Anchor,
18 mut events: Vec<StoredEvent>,
19 populate_expected_patch: bool,
20 cx: &mut App,
21) -> Option<Task<Result<ExampleSpec>>> {
22 let snapshot = buffer.read(cx).snapshot();
23 let file = snapshot.file()?;
24 let worktree_id = file.worktree_id(cx);
25 let repository = project.read(cx).active_repository(cx)?;
26 let repository_snapshot = repository.read(cx).snapshot();
27 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
28 let root_name = worktree.read(cx).root_name_str().to_owned();
29 let cursor_path: Arc<Path> = file.path().as_std_path().into();
30 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
31 return None;
32 }
33
34 let repository_url = repository_snapshot
35 .remote_origin_url
36 .clone()
37 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
38 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
39
40 let git_store = project.read(cx).git_store().clone();
41
42 Some(cx.spawn(async move |mut cx| {
43 let snapshots_by_path =
44 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
45
46 events.retain(|stored_event| {
47 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
48 let relative_path = strip_root_name(path, &root_name);
49 snapshots_by_path.contains_key(relative_path)
50 });
51
52 let line_comment_prefix = snapshot
53 .language()
54 .and_then(|lang| lang.config().line_comments.first())
55 .map(|s| s.to_string())
56 .unwrap_or_default();
57
58 let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
59 .background_executor()
60 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
61 .await;
62 let uncommitted_diff = cx
63 .background_executor()
64 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
65 .await;
66
67 let mut edit_history = String::new();
68 for stored_event in &events {
69 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
70 if !edit_history.ends_with('\n') {
71 edit_history.push('\n');
72 }
73 }
74
75 // Initialize an empty patch with context lines, to make it easy
76 // to write the expected patch by hand.
77 let mut expected_patches = Vec::new();
78 let mut rejected_patch = None;
79 if populate_expected_patch {
80 let mut empty_patch = String::new();
81 let start_row = cursor_excerpt_range.start.row + 1;
82 let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
83 writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
84 writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
85 writeln!(
86 &mut empty_patch,
87 "@@ -{},{} +{},{} @@",
88 start_row, row_count, start_row, row_count,
89 )
90 .ok();
91 for line in cursor_excerpt.lines() {
92 writeln!(&mut empty_patch, " {}", line).ok();
93 }
94
95 expected_patches.push(empty_patch.clone());
96 rejected_patch = Some(empty_patch);
97 }
98
99 let mut spec = ExampleSpec {
100 name: generate_timestamp_name(),
101 repository_url,
102 revision,
103 tags: Vec::new(),
104 reasoning: None,
105 uncommitted_diff,
106 cursor_path,
107 cursor_position: String::new(),
108 edit_history,
109 expected_patches,
110 rejected_patch,
111 telemetry: None,
112 human_feedback: Vec::new(),
113 rating: None,
114 };
115 spec.set_cursor_excerpt(
116 &cursor_excerpt,
117 cursor_offset_in_excerpt,
118 &line_comment_prefix,
119 );
120 Ok(spec)
121 }))
122}
123
/// Returns `path` with a leading `root_name` component removed.
///
/// Matching is component-wise, so `"projectx/a.rs"` is not stripped by
/// `"project"`. Paths that do not start with `root_name` (for example,
/// absolute paths outside the worktree) are returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
127
128fn write_event_with_relative_paths(
129 output: &mut String,
130 event: &zeta_prompt::Event,
131 root_name: &str,
132) {
133 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
134 for component in strip_root_name(path, root_name).components() {
135 output.push('/');
136 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
137 }
138 }
139
140 let zeta_prompt::Event::BufferChange {
141 path,
142 old_path,
143 diff,
144 ..
145 } = event;
146
147 output.push_str("--- a");
148 write_relative_path(output, old_path.as_ref(), root_name);
149 output.push_str("\n+++ b");
150 write_relative_path(output, path.as_ref(), root_name);
151 output.push('\n');
152 output.push_str(diff);
153}
154
155fn compute_cursor_excerpt(
156 snapshot: &language::BufferSnapshot,
157 cursor_anchor: language::Anchor,
158) -> (String, usize, Range<Point>) {
159 use text::ToOffset as _;
160
161 let cursor_point = cursor_anchor.to_point(snapshot);
162 let (_editable_range, context_range) =
163 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
164 let context_start_offset = context_range.start.to_offset(snapshot);
165 let cursor_offset = cursor_anchor.to_offset(snapshot);
166 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
167 let excerpt = snapshot
168 .text_for_range(context_range.clone())
169 .collect::<String>();
170 (excerpt, cursor_offset_in_excerpt, context_range)
171}
172
173async fn collect_snapshots(
174 project: &Entity<Project>,
175 git_store: &Entity<project::git_store::GitStore>,
176 worktree_id: WorktreeId,
177 events: &[StoredEvent],
178 cx: &mut gpui::AsyncApp,
179) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
180 let mut snapshots_by_path = HashMap::default();
181 for stored_event in events {
182 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
183 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
184 let project_path = project
185 .find_project_path(path, cx)
186 .filter(|path| path.worktree_id == worktree_id)?;
187 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
188 Some((project_path, relative_path))
189 }) {
190 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
191 let buffer = project
192 .update(cx, |project, cx| {
193 project.open_buffer(project_path.clone(), cx)
194 })
195 .await?;
196 let diff = git_store
197 .update(cx, |git_store, cx| {
198 git_store.open_uncommitted_diff(buffer.clone(), cx)
199 })
200 .await?;
201 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
202 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
203 }
204 }
205 }
206 Ok(snapshots_by_path)
207}
208
209fn compute_uncommitted_diff(
210 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
211) -> String {
212 let mut uncommitted_diff = String::new();
213 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
214 if let Some(head_text) = &diff_snapshot.base_text_string() {
215 let file_diff = language::unified_diff(head_text, &before_text.text());
216 if !file_diff.is_empty() {
217 let path_str = relative_path.to_string_lossy();
218 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
219 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
220 uncommitted_diff.push_str(&file_diff);
221 if !uncommitted_diff.ends_with('\n') {
222 uncommitted_diff.push('\n');
223 }
224 }
225 }
226 }
227 uncommitted_diff
228}
229
230fn generate_timestamp_name() -> String {
231 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
232 match format {
233 Ok(format) => {
234 let now = time::OffsetDateTime::now_local()
235 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
236 now.format(&format)
237 .unwrap_or_else(|_| "unknown-time".to_string())
238 }
239 Err(_) => "unknown-time".to_string(),
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// Scenario:
    /// - HEAD lacks two comments that exist on disk; these become the
    ///   uncommitted diff.
    /// - The buffer is then edited in memory; these edits become the edit
    ///   history.
    /// - An edit to a file outside the project's worktree is recorded by the
    ///   store but must not appear in the captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // On-disk contents: HEAD plus "comment 1" and "comment 2".
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the store so its subsequent edits are
        // available from `edit_history_for_project` below.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // In-memory edits: insert "comment 3" first, then "comment 4" above it.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Cursor at the very start of the buffer; request seeded expected and
        // rejected patches.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is timestamp-based; pin it before comparing.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. the file before the in-memory edits.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the in-worktree edits appear; the edit to
                // /external/external.rs is filtered out.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // With `populate_expected_patch`, both patches are the cursor
                // excerpt repeated as context lines.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );
    }

    /// Installs the global test state used by these tests:
    /// - a test `SettingsStore`,
    /// - test logging,
    /// - language-model support for a `Client` whose HTTP client answers 404,
    /// - a global `EditPredictionStore` built from that client and a `UserStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}