capture_example.rs

  1use crate::{StoredEvent, example_spec::ExampleSpec};
  2use anyhow::Result;
  3use buffer_diff::BufferDiffSnapshot;
  4#[cfg(test)]
  5use client::RefreshLlmTokenListener;
  6use collections::HashMap;
  7use gpui::{App, Entity, Task};
  8use language::Buffer;
  9use project::{Project, WorktreeId};
 10use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 11use text::{BufferSnapshot as TextBufferSnapshot, Point};
 12
 13pub fn capture_example(
 14    project: Entity<Project>,
 15    buffer: Entity<Buffer>,
 16    cursor_anchor: language::Anchor,
 17    mut events: Vec<StoredEvent>,
 18    populate_expected_patch: bool,
 19    cx: &mut App,
 20) -> Option<Task<Result<ExampleSpec>>> {
 21    let snapshot = buffer.read(cx).snapshot();
 22    let file = snapshot.file()?;
 23    let worktree_id = file.worktree_id(cx);
 24    let repository = project.read(cx).active_repository(cx)?;
 25    let repository_snapshot = repository.read(cx).snapshot();
 26    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 27    let root_name = worktree.read(cx).root_name_str().to_owned();
 28    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 29    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 30        return None;
 31    }
 32
 33    let repository_url = repository_snapshot
 34        .remote_origin_url
 35        .clone()
 36        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 37    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 38
 39    let git_store = project.read(cx).git_store().clone();
 40
 41    Some(cx.spawn(async move |mut cx| {
 42        let snapshots_by_path =
 43            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 44
 45        events.retain(|stored_event| {
 46            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 47            let relative_path = strip_root_name(path, &root_name);
 48            snapshots_by_path.contains_key(relative_path)
 49        });
 50
 51        let line_comment_prefix = snapshot
 52            .language()
 53            .and_then(|lang| lang.config().line_comments.first())
 54            .map(|s| s.to_string())
 55            .unwrap_or_default();
 56
 57        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 58            .background_executor()
 59            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 60            .await;
 61        let uncommitted_diff = cx
 62            .background_executor()
 63            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 64            .await;
 65
 66        let mut edit_history = String::new();
 67        for stored_event in &events {
 68            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 69            if !edit_history.ends_with('\n') {
 70                edit_history.push('\n');
 71            }
 72        }
 73
 74        // Initialize an empty patch with context lines, to make it easy
 75        // to write the expected patch by hand.
 76        let mut expected_patches = Vec::new();
 77        let mut rejected_patch = None;
 78        if populate_expected_patch {
 79            let mut empty_patch = String::new();
 80            let start_row = cursor_excerpt_range.start.row + 1;
 81            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 82            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 83            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 84            writeln!(
 85                &mut empty_patch,
 86                "@@ -{},{} +{},{} @@",
 87                start_row, row_count, start_row, row_count,
 88            )
 89            .ok();
 90            for line in cursor_excerpt.lines() {
 91                writeln!(&mut empty_patch, " {}", line).ok();
 92            }
 93
 94            expected_patches.push(empty_patch.clone());
 95            rejected_patch = Some(empty_patch);
 96        }
 97
 98        let mut spec = ExampleSpec {
 99            name: generate_timestamp_name(),
100            repository_url,
101            revision,
102            tags: Vec::new(),
103            reasoning: None,
104            uncommitted_diff,
105            cursor_path,
106            cursor_position: String::new(),
107            edit_history,
108            expected_patches,
109            rejected_patch,
110            telemetry: None,
111            human_feedback: Vec::new(),
112            rating: None,
113        };
114        spec.set_cursor_excerpt(
115            &cursor_excerpt,
116            cursor_offset_in_excerpt,
117            &line_comment_prefix,
118        );
119        Ok(spec)
120    }))
121}
122
/// Returns `path` with a leading `root_name` component removed.
///
/// Paths that do not start with `root_name` (for example absolute paths to files
/// outside the worktree) are returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
126
127fn write_event_with_relative_paths(
128    output: &mut String,
129    event: &zeta_prompt::Event,
130    root_name: &str,
131) {
132    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
133        for component in strip_root_name(path, root_name).components() {
134            output.push('/');
135            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
136        }
137    }
138
139    let zeta_prompt::Event::BufferChange {
140        path,
141        old_path,
142        diff,
143        ..
144    } = event;
145
146    output.push_str("--- a");
147    write_relative_path(output, old_path.as_ref(), root_name);
148    output.push_str("\n+++ b");
149    write_relative_path(output, path.as_ref(), root_name);
150    output.push('\n');
151    output.push_str(diff);
152}
153
154fn compute_cursor_excerpt(
155    snapshot: &language::BufferSnapshot,
156    cursor_anchor: language::Anchor,
157) -> (String, usize, Range<Point>) {
158    use text::ToOffset as _;
159    use text::ToPoint as _;
160
161    let cursor_offset = cursor_anchor.to_offset(snapshot);
162    let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
163        crate::cursor_excerpt::compute_cursor_excerpt(snapshot, cursor_offset);
164    let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
165        snapshot,
166        cursor_offset,
167        &excerpt_offset_range,
168    );
169    let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
170    let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
171        &excerpt_text,
172        cursor_offset_in_excerpt,
173        &syntax_ranges,
174        100,
175        50,
176    );
177    let context_text = excerpt_text[context_range.clone()].to_string();
178    let cursor_in_context = cursor_offset_in_excerpt.saturating_sub(context_range.start);
179    let context_buffer_start =
180        (excerpt_offset_range.start + context_range.start).to_point(snapshot);
181    let context_buffer_end = (excerpt_offset_range.start + context_range.end).to_point(snapshot);
182    (
183        context_text,
184        cursor_in_context,
185        context_buffer_start..context_buffer_end,
186    )
187}
188
189async fn collect_snapshots(
190    project: &Entity<Project>,
191    git_store: &Entity<project::git_store::GitStore>,
192    worktree_id: WorktreeId,
193    events: &[StoredEvent],
194    cx: &mut gpui::AsyncApp,
195) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
196    let mut snapshots_by_path = HashMap::default();
197    for stored_event in events {
198        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
199        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
200            let project_path = project
201                .find_project_path(path, cx)
202                .filter(|path| path.worktree_id == worktree_id)?;
203            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
204            Some((project_path, relative_path))
205        }) {
206            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
207                let buffer = project
208                    .update(cx, |project, cx| {
209                        project.open_buffer(project_path.clone(), cx)
210                    })
211                    .await?;
212                let diff = git_store
213                    .update(cx, |git_store, cx| {
214                        git_store.open_uncommitted_diff(buffer.clone(), cx)
215                    })
216                    .await?;
217                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
218                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
219            }
220        }
221    }
222    Ok(snapshots_by_path)
223}
224
225fn compute_uncommitted_diff(
226    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
227) -> String {
228    let mut uncommitted_diff = String::new();
229    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
230        if let Some(head_text) = &diff_snapshot.base_text_string() {
231            let file_diff = language::unified_diff(head_text, &before_text.text());
232            if !file_diff.is_empty() {
233                let path_str = relative_path.to_string_lossy();
234                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
235                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
236                uncommitted_diff.push_str(&file_diff);
237                if !uncommitted_diff.ends_with('\n') {
238                    uncommitted_diff.push('\n');
239                }
240            }
241        }
242    }
243    uncommitted_diff
244}
245
246fn generate_timestamp_name() -> String {
247    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
248    match format {
249        Ok(format) => {
250            let now = time::OffsetDateTime::now_local()
251                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
252            now.format(&format)
253                .unwrap_or_else(|_| "unknown-time".to_string())
254        }
255        Err(_) => "unknown-time".to_string(),
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    // Exercises `capture_example` end to end:
    // - uncommitted on-disk changes show up in `uncommitted_diff`;
    // - unsaved buffer edits show up in `edit_history`;
    // - an edit to a file outside the worktree is left out of the example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // File contents at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // File contents on disk: "comment 1" and "comment 2" are uncommitted additions.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // A file outside the project's worktree; its edits must not reach the example.
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        let repo_dir = Path::new("/project/.git");
        fs.set_head_for_repo(
            repo_dir,
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(repo_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
        cx.run_until_parked();

        // Unsaved buffer edits; these are what the edit history should record.
        buffer.update(cx, |buffer, cx| {
            let comment_3_at = Point::new(6, 0);
            buffer.edit([(comment_3_at..comment_3_at, "    // comment 3\n")], None, cx);
            let comment_4_at = Point::new(4, 0);
            buffer.edit([(comment_4_at..comment_4_at, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        store.update(cx, |store, cx| {
            store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let start = Point::new(0, 0);
            buffer.edit([(start..start, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // The store does record the external edit; `capture_example` must drop it.
        let events = store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Cursor at the very start of the buffer; request the hand-editable patch templates.
        let mut example = cx
            .update(|cx| {
                capture_example(project.clone(), buffer.clone(), Anchor::MIN, events, true, cx)
                    .unwrap()
            })
            .await
            .unwrap();
        // The generated name is timestamp-based; pin it so the spec can be compared whole.
        example.name = "test".to_string();

        // Context-only template emitted for both `expected_patches` and `rejected_patch`.
        // NOTE(review): the hunk header claims 16 lines although only 15 context lines
        // follow; this pins the row count `capture_example` currently computes.
        let template_patch = indoc! {"
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,16 +1,16 @@
             fn main() {
                 // comment 1
                 one();
                 two();
                 // comment 4
                 three();
                 four();
                 // comment 3
                 five();
                 six();
                 seven();
                 eight();
                 // comment 2
                 nine();
             }
        "}
        .to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. the on-disk text as it was before the recorded buffer edits.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the in-worktree buffer edits; the external edit is filtered out.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![template_patch.clone()],
                rejected_patch: Some(template_patch),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );
    }

    // Installs the globals this test relies on:
    // - a test settings store and logging;
    // - an LLM token refresh listener;
    // - the global `EditPredictionStore`, built on a client whose fake HTTP client comes
    //   from `FakeHttpClient::with_404_response`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(cx);
            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}