capture_example.rs

  1use crate::{StoredEvent, example_spec::ExampleSpec};
  2use anyhow::Result;
  3use buffer_diff::BufferDiffSnapshot;
  4use collections::HashMap;
  5use gpui::{App, Entity, Task};
  6use language::Buffer;
  7use project::{Project, WorktreeId};
  8use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
  9use text::{BufferSnapshot as TextBufferSnapshot, Point};
 10
 11pub fn capture_example(
 12    project: Entity<Project>,
 13    buffer: Entity<Buffer>,
 14    cursor_anchor: language::Anchor,
 15    mut events: Vec<StoredEvent>,
 16    populate_expected_patch: bool,
 17    cx: &mut App,
 18) -> Option<Task<Result<ExampleSpec>>> {
 19    let snapshot = buffer.read(cx).snapshot();
 20    let file = snapshot.file()?;
 21    let worktree_id = file.worktree_id(cx);
 22    let repository = project.read(cx).active_repository(cx)?;
 23    let repository_snapshot = repository.read(cx).snapshot();
 24    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 25    let root_name = worktree.read(cx).root_name_str().to_owned();
 26    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 27    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 28        return None;
 29    }
 30
 31    let repository_url = repository_snapshot
 32        .remote_origin_url
 33        .clone()
 34        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 35    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 36
 37    let git_store = project.read(cx).git_store().clone();
 38
 39    Some(cx.spawn(async move |mut cx| {
 40        let snapshots_by_path =
 41            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 42
 43        events.retain(|stored_event| {
 44            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 45            let relative_path = strip_root_name(path, &root_name);
 46            snapshots_by_path.contains_key(relative_path)
 47        });
 48
 49        let line_comment_prefix = snapshot
 50            .language()
 51            .and_then(|lang| lang.config().line_comments.first())
 52            .map(|s| s.to_string())
 53            .unwrap_or_default();
 54
 55        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 56            .background_executor()
 57            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 58            .await;
 59        let uncommitted_diff = cx
 60            .background_executor()
 61            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 62            .await;
 63
 64        let mut edit_history = String::new();
 65        for stored_event in &events {
 66            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 67            if !edit_history.ends_with('\n') {
 68                edit_history.push('\n');
 69            }
 70        }
 71
 72        // Initialize an empty patch with context lines, to make it easy
 73        // to write the expected patch by hand.
 74        let mut expected_patches = Vec::new();
 75        let mut rejected_patch = None;
 76        if populate_expected_patch {
 77            let mut empty_patch = String::new();
 78            let start_row = cursor_excerpt_range.start.row + 1;
 79            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 80            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 81            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 82            writeln!(
 83                &mut empty_patch,
 84                "@@ -{},{} +{},{} @@",
 85                start_row, row_count, start_row, row_count,
 86            )
 87            .ok();
 88            for line in cursor_excerpt.lines() {
 89                writeln!(&mut empty_patch, " {}", line).ok();
 90            }
 91
 92            expected_patches.push(empty_patch.clone());
 93            rejected_patch = Some(empty_patch);
 94        }
 95
 96        let mut spec = ExampleSpec {
 97            name: generate_timestamp_name(),
 98            repository_url,
 99            revision,
100            tags: Vec::new(),
101            reasoning: None,
102            uncommitted_diff,
103            cursor_path,
104            cursor_position: String::new(),
105            edit_history,
106            expected_patches,
107            rejected_patch,
108            telemetry: None,
109            human_feedback: Vec::new(),
110            rating: None,
111        };
112        spec.set_cursor_excerpt(
113            &cursor_excerpt,
114            cursor_offset_in_excerpt,
115            &line_comment_prefix,
116        );
117        Ok(spec)
118    }))
119}
120
121fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
122    path.strip_prefix(root_name).unwrap_or(path)
123}
124
125fn write_event_with_relative_paths(
126    output: &mut String,
127    event: &zeta_prompt::Event,
128    root_name: &str,
129) {
130    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
131        for component in strip_root_name(path, root_name).components() {
132            output.push('/');
133            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
134        }
135    }
136
137    let zeta_prompt::Event::BufferChange {
138        path,
139        old_path,
140        diff,
141        ..
142    } = event;
143
144    output.push_str("--- a");
145    write_relative_path(output, old_path.as_ref(), root_name);
146    output.push_str("\n+++ b");
147    write_relative_path(output, path.as_ref(), root_name);
148    output.push('\n');
149    output.push_str(diff);
150}
151
152fn compute_cursor_excerpt(
153    snapshot: &language::BufferSnapshot,
154    cursor_anchor: language::Anchor,
155) -> (String, usize, Range<Point>) {
156    use text::ToOffset as _;
157    use text::ToPoint as _;
158
159    let cursor_offset = cursor_anchor.to_offset(snapshot);
160    let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
161        crate::cursor_excerpt::compute_cursor_excerpt(snapshot, cursor_offset);
162    let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
163        snapshot,
164        cursor_offset,
165        &excerpt_offset_range,
166    );
167    let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
168    let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
169        &excerpt_text,
170        cursor_offset_in_excerpt,
171        &syntax_ranges,
172        100,
173        50,
174    );
175    let context_text = excerpt_text[context_range.clone()].to_string();
176    let cursor_in_context = cursor_offset_in_excerpt.saturating_sub(context_range.start);
177    let context_buffer_start =
178        (excerpt_offset_range.start + context_range.start).to_point(snapshot);
179    let context_buffer_end = (excerpt_offset_range.start + context_range.end).to_point(snapshot);
180    (
181        context_text,
182        cursor_in_context,
183        context_buffer_start..context_buffer_end,
184    )
185}
186
187async fn collect_snapshots(
188    project: &Entity<Project>,
189    git_store: &Entity<project::git_store::GitStore>,
190    worktree_id: WorktreeId,
191    events: &[StoredEvent],
192    cx: &mut gpui::AsyncApp,
193) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
194    let mut snapshots_by_path = HashMap::default();
195    for stored_event in events {
196        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
197        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
198            let project_path = project
199                .find_project_path(path, cx)
200                .filter(|path| path.worktree_id == worktree_id)?;
201            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
202            Some((project_path, relative_path))
203        }) {
204            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
205                let buffer = project
206                    .update(cx, |project, cx| {
207                        project.open_buffer(project_path.clone(), cx)
208                    })
209                    .await?;
210                let diff = git_store
211                    .update(cx, |git_store, cx| {
212                        git_store.open_uncommitted_diff(buffer.clone(), cx)
213                    })
214                    .await?;
215                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
216                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
217            }
218        }
219    }
220    Ok(snapshots_by_path)
221}
222
223fn compute_uncommitted_diff(
224    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
225) -> String {
226    let mut uncommitted_diff = String::new();
227    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
228        if let Some(head_text) = &diff_snapshot.base_text_string() {
229            let file_diff = language::unified_diff(head_text, &before_text.text());
230            if !file_diff.is_empty() {
231                let path_str = relative_path.to_string_lossy();
232                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
233                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
234                uncommitted_diff.push_str(&file_diff);
235                if !uncommitted_diff.ends_with('\n') {
236                    uncommitted_diff.push('\n');
237                }
238            }
239        }
240    }
241    uncommitted_diff
242}
243
244fn generate_timestamp_name() -> String {
245    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
246    match format {
247        Ok(format) => {
248            let now = time::OffsetDateTime::now_local()
249                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
250            now.format(&format)
251                .unwrap_or_else(|_| "unknown-time".to_string())
252        }
253        Err(_) => "unknown-time".to_string(),
254    }
255}
256
257#[cfg(test)]
258mod tests {
259    use super::*;
260    use crate::EditPredictionStore;
261    use client::RefreshLlmTokenListener;
262    use client::{Client, UserStore};
263    use clock::FakeSystemClock;
264    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
265    use indoc::indoc;
266    use language::{Anchor, Point};
267    use project::{FakeFs, Project};
268    use serde_json::json;
269    use settings::SettingsStore;
270    use std::path::Path;
271
272    #[gpui::test]
273    async fn test_capture_example(cx: &mut TestAppContext) {
274        init_test(cx);
275        let fs = FakeFs::new(cx.executor());
276
277        let committed_contents = indoc! {"
278            fn main() {
279                one();
280                two();
281                three();
282                four();
283                five();
284                six();
285                seven();
286                eight();
287                nine();
288            }
289        "};
290
291        let disk_contents = indoc! {"
292            fn main() {
293                // comment 1
294                one();
295                two();
296                three();
297                four();
298                five();
299                six();
300                seven();
301                eight();
302                // comment 2
303                nine();
304            }
305        "};
306
307        fs.insert_tree(
308            "/project",
309            json!({
310                ".git": {},
311                "src": {
312                    "main.rs": disk_contents,
313                }
314            }),
315        )
316        .await;
317
318        // Create an external file outside the main project
319        fs.insert_tree(
320            "/external",
321            json!({
322                "external.rs": "fn external() {}\n",
323            }),
324        )
325        .await;
326
327        fs.set_head_for_repo(
328            Path::new("/project/.git"),
329            &[("src/main.rs", committed_contents.to_string())],
330            "abc123def456",
331        );
332        fs.set_remote_for_repo(
333            Path::new("/project/.git"),
334            "origin",
335            "https://github.com/test/repo.git",
336        );
337
338        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
339
340        let buffer = project
341            .update(cx, |project, cx| {
342                project.open_local_buffer("/project/src/main.rs", cx)
343            })
344            .await
345            .unwrap();
346
347        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
348        ep_store.update(cx, |ep_store, cx| {
349            ep_store.register_buffer(&buffer, &project, cx)
350        });
351        cx.run_until_parked();
352
353        buffer.update(cx, |buffer, cx| {
354            let point = Point::new(6, 0);
355            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
356            let point = Point::new(4, 0);
357            buffer.edit([(point..point, "    // comment 4\n")], None, cx);
358
359            pretty_assertions::assert_eq!(
360                buffer.text(),
361                indoc! {"
362                    fn main() {
363                        // comment 1
364                        one();
365                        two();
366                        // comment 4
367                        three();
368                        four();
369                        // comment 3
370                        five();
371                        six();
372                        seven();
373                        eight();
374                        // comment 2
375                        nine();
376                    }
377                "}
378            );
379        });
380        cx.run_until_parked();
381
382        // Open and edit an external file (outside the main project's worktree)
383        let external_buffer = project
384            .update(cx, |project, cx| {
385                project.open_local_buffer("/external/external.rs", cx)
386            })
387            .await
388            .unwrap();
389        ep_store.update(cx, |ep_store, cx| {
390            ep_store.register_buffer(&external_buffer, &project, cx)
391        });
392        cx.run_until_parked();
393        external_buffer.update(cx, |buffer, cx| {
394            let point = Point::new(0, 0);
395            buffer.edit([(point..point, "// external edit\n")], None, cx);
396        });
397        cx.run_until_parked();
398
399        // Verify the external edit was recorded in events
400        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
401        assert!(
402            matches!(
403                events
404                    .last()
405                    .unwrap()
406                    .event
407                    .as_ref(),
408                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
409            ),
410            "external file edit should be in events"
411        );
412
413        let mut example = cx
414            .update(|cx| {
415                capture_example(
416                    project.clone(),
417                    buffer.clone(),
418                    Anchor::min_for_buffer(buffer.read(cx).remote_id()),
419                    events,
420                    true,
421                    cx,
422                )
423                .unwrap()
424            })
425            .await
426            .unwrap();
427        example.name = "test".to_string();
428
429        pretty_assertions::assert_eq!(
430            example,
431            ExampleSpec {
432                name: "test".to_string(),
433                repository_url: "https://github.com/test/repo.git".to_string(),
434                revision: "abc123def456".to_string(),
435                tags: Vec::new(),
436                reasoning: None,
437                uncommitted_diff: indoc! {"
438                    --- a/src/main.rs
439                    +++ b/src/main.rs
440                    @@ -1,4 +1,5 @@
441                     fn main() {
442                    +    // comment 1
443                         one();
444                         two();
445                         three();
446                    @@ -7,5 +8,6 @@
447                         six();
448                         seven();
449                         eight();
450                    +    // comment 2
451                         nine();
452                     }
453                "}
454                .to_string(),
455                cursor_path: Path::new("src/main.rs").into(),
456                cursor_position: indoc! {"
457                    fn main() {
458                    ^[CURSOR_POSITION]
459                        // comment 1
460                        one();
461                        two();
462                        // comment 4
463                        three();
464                        four();
465                        // comment 3
466                        five();
467                        six();
468                        seven();
469                        eight();
470                        // comment 2
471                        nine();
472                    }
473                "}
474                .to_string(),
475                edit_history: indoc! {"
476                    --- a/src/main.rs
477                    +++ b/src/main.rs
478                    @@ -2,8 +2,10 @@
479                         // comment 1
480                         one();
481                         two();
482                    +    // comment 4
483                         three();
484                         four();
485                    +    // comment 3
486                         five();
487                         six();
488                         seven();
489                "}
490                .to_string(),
491                expected_patches: vec![
492                    indoc! {"
493                        --- a/src/main.rs
494                        +++ b/src/main.rs
495                        @@ -1,16 +1,16 @@
496                         fn main() {
497                             // comment 1
498                             one();
499                             two();
500                             // comment 4
501                             three();
502                             four();
503                             // comment 3
504                             five();
505                             six();
506                             seven();
507                             eight();
508                             // comment 2
509                             nine();
510                         }
511                    "}
512                    .to_string()
513                ],
514                rejected_patch: Some(
515                    indoc! {"
516                        --- a/src/main.rs
517                        +++ b/src/main.rs
518                        @@ -1,16 +1,16 @@
519                         fn main() {
520                             // comment 1
521                             one();
522                             two();
523                             // comment 4
524                             three();
525                             four();
526                             // comment 3
527                             five();
528                             six();
529                             seven();
530                             eight();
531                             // comment 2
532                             nine();
533                         }
534                    "}
535                    .to_string()
536                ),
537                telemetry: None,
538                human_feedback: Vec::new(),
539                rating: None,
540            }
541        );
542    }
543
544    fn init_test(cx: &mut TestAppContext) {
545        cx.update(|cx| {
546            let settings_store = SettingsStore::test(cx);
547            cx.set_global(settings_store);
548            zlog::init_test();
549            let http_client = FakeHttpClient::with_404_response();
550            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
551            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
552            language_model::init(cx);
553            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
554            EditPredictionStore::global(&client, &user_store, cx);
555        })
556    }
557}