capture_example.rs

  1use crate::{
  2    StoredEvent, cursor_excerpt::editable_and_context_ranges_for_cursor_position,
  3    example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use gpui::{App, Entity, Task};
  9use language::{Buffer, ToPoint as _};
 10use project::{Project, WorktreeId};
 11use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 12use text::{BufferSnapshot as TextBufferSnapshot, Point};
 13
 14pub fn capture_example(
 15    project: Entity<Project>,
 16    buffer: Entity<Buffer>,
 17    cursor_anchor: language::Anchor,
 18    mut events: Vec<StoredEvent>,
 19    populate_expected_patch: bool,
 20    cx: &mut App,
 21) -> Option<Task<Result<ExampleSpec>>> {
 22    let snapshot = buffer.read(cx).snapshot();
 23    let file = snapshot.file()?;
 24    let worktree_id = file.worktree_id(cx);
 25    let repository = project.read(cx).active_repository(cx)?;
 26    let repository_snapshot = repository.read(cx).snapshot();
 27    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 28    let root_name = worktree.read(cx).root_name_str().to_owned();
 29    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 30    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 31        return None;
 32    }
 33
 34    let repository_url = repository_snapshot
 35        .remote_origin_url
 36        .clone()
 37        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 38    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 39
 40    let git_store = project.read(cx).git_store().clone();
 41
 42    Some(cx.spawn(async move |mut cx| {
 43        let snapshots_by_path =
 44            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 45
 46        events.retain(|stored_event| {
 47            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 48            let relative_path = strip_root_name(path, &root_name);
 49            snapshots_by_path.contains_key(relative_path)
 50        });
 51
 52        let line_comment_prefix = snapshot
 53            .language()
 54            .and_then(|lang| lang.config().line_comments.first())
 55            .map(|s| s.to_string())
 56            .unwrap_or_default();
 57
 58        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 59            .background_executor()
 60            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 61            .await;
 62        let uncommitted_diff = cx
 63            .background_executor()
 64            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 65            .await;
 66
 67        let mut edit_history = String::new();
 68        for stored_event in &events {
 69            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 70            if !edit_history.ends_with('\n') {
 71                edit_history.push('\n');
 72            }
 73        }
 74
 75        // Initialize an empty patch with context lines, to make it easy
 76        // to write the expected patch by hand.
 77        let mut expected_patches = Vec::new();
 78        let mut rejected_patch = None;
 79        if populate_expected_patch {
 80            let mut empty_patch = String::new();
 81            let start_row = cursor_excerpt_range.start.row + 1;
 82            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 83            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 84            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 85            writeln!(
 86                &mut empty_patch,
 87                "@@ -{},{} +{},{} @@",
 88                start_row, row_count, start_row, row_count,
 89            )
 90            .ok();
 91            for line in cursor_excerpt.lines() {
 92                writeln!(&mut empty_patch, " {}", line).ok();
 93            }
 94
 95            expected_patches.push(empty_patch.clone());
 96            rejected_patch = Some(empty_patch);
 97        }
 98
 99        let mut spec = ExampleSpec {
100            name: generate_timestamp_name(),
101            repository_url,
102            revision,
103            tags: Vec::new(),
104            reasoning: None,
105            uncommitted_diff,
106            cursor_path,
107            cursor_position: String::new(),
108            edit_history,
109            expected_patches,
110            rejected_patch,
111            telemetry: None,
112            human_feedback: Vec::new(),
113            rating: None,
114        };
115        spec.set_cursor_excerpt(
116            &cursor_excerpt,
117            cursor_offset_in_excerpt,
118            &line_comment_prefix,
119        );
120        Ok(spec)
121    }))
122}
123
/// Removes a leading `root_name` component from `path`.
///
/// Matching is component-wise (as with [`Path::strip_prefix`]); when `path`
/// does not start with `root_name`, it is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
127
128fn write_event_with_relative_paths(
129    output: &mut String,
130    event: &zeta_prompt::Event,
131    root_name: &str,
132) {
133    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
134        for component in strip_root_name(path, root_name).components() {
135            output.push('/');
136            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
137        }
138    }
139
140    let zeta_prompt::Event::BufferChange {
141        path,
142        old_path,
143        diff,
144        ..
145    } = event;
146
147    output.push_str("--- a");
148    write_relative_path(output, old_path.as_ref(), root_name);
149    output.push_str("\n+++ b");
150    write_relative_path(output, path.as_ref(), root_name);
151    output.push('\n');
152    output.push_str(diff);
153}
154
155fn compute_cursor_excerpt(
156    snapshot: &language::BufferSnapshot,
157    cursor_anchor: language::Anchor,
158) -> (String, usize, Range<Point>) {
159    use text::ToOffset as _;
160
161    let cursor_point = cursor_anchor.to_point(snapshot);
162    let (_editable_range, context_range) =
163        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
164    let context_start_offset = context_range.start.to_offset(snapshot);
165    let cursor_offset = cursor_anchor.to_offset(snapshot);
166    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
167    let excerpt = snapshot
168        .text_for_range(context_range.clone())
169        .collect::<String>();
170    (excerpt, cursor_offset_in_excerpt, context_range)
171}
172
173async fn collect_snapshots(
174    project: &Entity<Project>,
175    git_store: &Entity<project::git_store::GitStore>,
176    worktree_id: WorktreeId,
177    events: &[StoredEvent],
178    cx: &mut gpui::AsyncApp,
179) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
180    let mut snapshots_by_path = HashMap::default();
181    for stored_event in events {
182        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
183        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
184            let project_path = project
185                .find_project_path(path, cx)
186                .filter(|path| path.worktree_id == worktree_id)?;
187            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
188            Some((project_path, relative_path))
189        }) {
190            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
191                let buffer = project
192                    .update(cx, |project, cx| {
193                        project.open_buffer(project_path.clone(), cx)
194                    })
195                    .await?;
196                let diff = git_store
197                    .update(cx, |git_store, cx| {
198                        git_store.open_uncommitted_diff(buffer.clone(), cx)
199                    })
200                    .await?;
201                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
202                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
203            }
204        }
205    }
206    Ok(snapshots_by_path)
207}
208
209fn compute_uncommitted_diff(
210    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
211) -> String {
212    let mut uncommitted_diff = String::new();
213    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
214        if let Some(head_text) = &diff_snapshot.base_text_string() {
215            let file_diff = language::unified_diff(head_text, &before_text.text());
216            if !file_diff.is_empty() {
217                let path_str = relative_path.to_string_lossy();
218                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
219                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
220                uncommitted_diff.push_str(&file_diff);
221                if !uncommitted_diff.ends_with('\n') {
222                    uncommitted_diff.push('\n');
223                }
224            }
225        }
226    }
227    uncommitted_diff
228}
229
230fn generate_timestamp_name() -> String {
231    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
232    match format {
233        Ok(format) => {
234            let now = time::OffsetDateTime::now_local()
235                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
236            now.format(&format)
237                .unwrap_or_else(|_| "unknown-time".to_string())
238        }
239        Err(_) => "unknown-time".to_string(),
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// Scenario:
    /// - HEAD holds `committed_contents`; the file on disk additionally has
    ///   `// comment 1` and `// comment 2`, which must show up in
    ///   `uncommitted_diff`.
    /// - After the buffer is registered with the edit-prediction store,
    ///   `// comment 3` and `// comment 4` are inserted; these must show up in
    ///   `edit_history`.
    /// - A file outside the project worktree is also edited; that event is
    ///   recorded by the store but must not appear in the captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Working-copy contents: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // The fake repository's HEAD commit and remote supply `revision` and
        // `repository_url` in the expected spec.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register before editing so the following edits are captured as events.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Insert the later line first so the earlier insertion point (row 4)
        // is not shifted by the first edit.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| store.edit_history_for_project(&project, cx));
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Cursor at the very start of the buffer; ask for seeded expected patches.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is a timestamp, so pin it before comparing.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. the text before the buffer edits: only comments 1 and 2.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the main.rs edits (comments 3 and 4) remain; the
                // external-file event has been filtered out.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // Context-only patch seeded because `populate_expected_patch` is true.
                // NOTE(review): the header counts 16 rows while the hunk body
                // lists 15 lines; this pins the current row-count arithmetic
                // in `capture_example` — confirm that is intended.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                // Seeded with the same context-only patch as `expected_patches`.
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store,
    /// test logging, a client whose fake HTTP client answers every request
    /// with 404, language-model support, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}