capture_example.rs

  1use crate::{StoredEvent, example_spec::ExampleSpec};
  2use anyhow::Result;
  3use buffer_diff::BufferDiffSnapshot;
  4use collections::HashMap;
  5use gpui::{App, Entity, Task};
  6use language::Buffer;
  7use project::{Project, WorktreeId};
  8use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
  9use text::{BufferSnapshot as TextBufferSnapshot, Point};
 10
 11pub fn capture_example(
 12    project: Entity<Project>,
 13    buffer: Entity<Buffer>,
 14    cursor_anchor: language::Anchor,
 15    mut events: Vec<StoredEvent>,
 16    populate_expected_patch: bool,
 17    cx: &mut App,
 18) -> Option<Task<Result<ExampleSpec>>> {
 19    let snapshot = buffer.read(cx).snapshot();
 20    let file = snapshot.file()?;
 21    let worktree_id = file.worktree_id(cx);
 22    let repository = project.read(cx).active_repository(cx)?;
 23    let repository_snapshot = repository.read(cx).snapshot();
 24    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 25    let root_name = worktree.read(cx).root_name_str().to_owned();
 26    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 27    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 28        return None;
 29    }
 30
 31    let repository_url = repository_snapshot
 32        .remote_origin_url
 33        .clone()
 34        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 35    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 36
 37    let git_store = project.read(cx).git_store().clone();
 38
 39    Some(cx.spawn(async move |mut cx| {
 40        let snapshots_by_path =
 41            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 42
 43        events.retain(|stored_event| {
 44            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 45            let relative_path = strip_root_name(path, &root_name);
 46            snapshots_by_path.contains_key(relative_path)
 47        });
 48
 49        let line_comment_prefix = snapshot
 50            .language()
 51            .and_then(|lang| lang.config().line_comments.first())
 52            .map(|s| s.to_string())
 53            .unwrap_or_default();
 54
 55        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 56            .background_executor()
 57            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 58            .await;
 59        let uncommitted_diff = cx
 60            .background_executor()
 61            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 62            .await;
 63
 64        let mut edit_history = String::new();
 65        for stored_event in &events {
 66            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 67            if !edit_history.ends_with('\n') {
 68                edit_history.push('\n');
 69            }
 70        }
 71
 72        // Initialize an empty patch with context lines, to make it easy
 73        // to write the expected patch by hand.
 74        let mut expected_patches = Vec::new();
 75        let mut rejected_patch = None;
 76        if populate_expected_patch {
 77            let mut empty_patch = String::new();
 78            let start_row = cursor_excerpt_range.start.row + 1;
 79            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 80            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 81            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 82            writeln!(
 83                &mut empty_patch,
 84                "@@ -{},{} +{},{} @@",
 85                start_row, row_count, start_row, row_count,
 86            )
 87            .ok();
 88            for line in cursor_excerpt.lines() {
 89                writeln!(&mut empty_patch, " {}", line).ok();
 90            }
 91
 92            expected_patches.push(empty_patch.clone());
 93            rejected_patch = Some(empty_patch);
 94        }
 95
 96        let mut spec = ExampleSpec {
 97            name: generate_timestamp_name(),
 98            repository_url,
 99            revision,
100            tags: Vec::new(),
101            reasoning: None,
102            uncommitted_diff,
103            cursor_path,
104            cursor_position: String::new(),
105            edit_history,
106            expected_patches,
107            rejected_patch,
108            telemetry: None,
109            human_feedback: Vec::new(),
110            rating: None,
111        };
112        spec.set_cursor_excerpt(
113            &cursor_excerpt,
114            cursor_offset_in_excerpt,
115            &line_comment_prefix,
116        );
117        Ok(spec)
118    }))
119}
120
/// Returns `path` with the leading `root_name` removed (compared component-wise),
/// or `path` itself when it does not begin with `root_name`.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
124
125fn write_event_with_relative_paths(
126    output: &mut String,
127    event: &zeta_prompt::Event,
128    root_name: &str,
129) {
130    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
131        for component in strip_root_name(path, root_name).components() {
132            output.push('/');
133            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
134        }
135    }
136
137    let zeta_prompt::Event::BufferChange {
138        path,
139        old_path,
140        diff,
141        ..
142    } = event;
143
144    output.push_str("--- a");
145    write_relative_path(output, old_path.as_ref(), root_name);
146    output.push_str("\n+++ b");
147    write_relative_path(output, path.as_ref(), root_name);
148    output.push('\n');
149    output.push_str(diff);
150}
151
152fn compute_cursor_excerpt(
153    snapshot: &language::BufferSnapshot,
154    cursor_anchor: language::Anchor,
155) -> (String, usize, Range<Point>) {
156    use text::ToOffset as _;
157    use text::ToPoint as _;
158
159    let cursor_offset = cursor_anchor.to_offset(snapshot);
160    let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
161        crate::cursor_excerpt::compute_cursor_excerpt(snapshot, cursor_offset);
162    let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
163        snapshot,
164        cursor_offset,
165        &excerpt_offset_range,
166    );
167    let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
168    let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
169        &excerpt_text,
170        cursor_offset_in_excerpt,
171        &syntax_ranges,
172        100,
173        50,
174    );
175    let context_text = excerpt_text[context_range.clone()].to_string();
176    let cursor_in_context = cursor_offset_in_excerpt.saturating_sub(context_range.start);
177    let context_buffer_start =
178        (excerpt_offset_range.start + context_range.start).to_point(snapshot);
179    let context_buffer_end = (excerpt_offset_range.start + context_range.end).to_point(snapshot);
180    (
181        context_text,
182        cursor_in_context,
183        context_buffer_start..context_buffer_end,
184    )
185}
186
187async fn collect_snapshots(
188    project: &Entity<Project>,
189    git_store: &Entity<project::git_store::GitStore>,
190    worktree_id: WorktreeId,
191    events: &[StoredEvent],
192    cx: &mut gpui::AsyncApp,
193) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
194    let mut snapshots_by_path = HashMap::default();
195    for stored_event in events {
196        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
197        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
198            let project_path = project
199                .find_project_path(path, cx)
200                .filter(|path| path.worktree_id == worktree_id)?;
201            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
202            Some((project_path, relative_path))
203        }) {
204            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
205                let buffer = project
206                    .update(cx, |project, cx| {
207                        project.open_buffer(project_path.clone(), cx)
208                    })
209                    .await?;
210                let diff = git_store
211                    .update(cx, |git_store, cx| {
212                        git_store.open_uncommitted_diff(buffer.clone(), cx)
213                    })
214                    .await?;
215                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
216                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
217            }
218        }
219    }
220    Ok(snapshots_by_path)
221}
222
223fn compute_uncommitted_diff(
224    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
225) -> String {
226    let mut uncommitted_diff = String::new();
227    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
228        if let Some(head_text) = &diff_snapshot.base_text_string() {
229            let file_diff = language::unified_diff(head_text, &before_text.text());
230            if !file_diff.is_empty() {
231                let path_str = relative_path.to_string_lossy();
232                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
233                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
234                uncommitted_diff.push_str(&file_diff);
235                if !uncommitted_diff.ends_with('\n') {
236                    uncommitted_diff.push('\n');
237                }
238            }
239        }
240    }
241    uncommitted_diff
242}
243
244fn generate_timestamp_name() -> String {
245    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
246    match format {
247        Ok(format) => {
248            let now = time::OffsetDateTime::now_local()
249                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
250            now.format(&format)
251                .unwrap_or_else(|_| "unknown-time".to_string())
252        }
253        Err(_) => "unknown-time".to_string(),
254    }
255}
256
// Tests for `capture_example`, run against a fake filesystem and git repository.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    // End-to-end check of `capture_example`. It verifies that:
    // - on-disk changes relative to HEAD land in `uncommitted_diff`;
    // - recorded buffer edits land in `edit_history`;
    // - edits to a file outside the worktree are dropped;
    // - the cursor excerpt and the seeded expected/rejected patches are produced.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of src/main.rs at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // On-disk contents: "comment 1" and "comment 2" are uncommitted relative
        // to HEAD and are expected to appear in `uncommitted_diff`.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer so that subsequent edits are recorded as events.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Unsaved edits ("comment 3", then "comment 4" above it); these are
        // expected to appear in `edit_history`, not in `uncommitted_diff`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer and with the
        // expected/rejected patches seeded.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The real name is a timestamp; overwrite it so the comparison is stable.
        example.name = "test".to_string();

        // Note that the external file's edit is absent from `edit_history`.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // NOTE(review): the seeded hunk header claims 16 rows, but only 15
                // context lines follow. This pins capture_example's current row
                // count (the excerpt ends at column 0 after the trailing newline);
                // update this expectation together with capture_example if that
                // count is corrected.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );
    }

    // Installs the globals this test relies on:
    // - a test settings store;
    // - test logging;
    // - a client and user store backed by a fake HTTP client that answers 404;
    // - `language_model` initialization;
    // - the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(user_store.clone(), client.clone(), cx);
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}