capture_example.rs

  1use crate::{
  2    StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
  4    example_spec::{
  5        CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
  6        ExampleSpec, MAX_CURSOR_FILE_SIZE,
  7    },
  8};
  9use anyhow::Result;
 10use buffer_diff::BufferDiffSnapshot;
 11use collections::HashMap;
 12use gpui::{App, Entity, Task};
 13use language::{Buffer, ToPoint as _};
 14use project::{Project, WorktreeId};
 15use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 16use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
 17
/// Threshold used by `should_send_testing_zeta2_request`: that function returns `true`
/// when a random value reduced modulo 10,000 is below this number, so `500` targets
/// roughly 5% of calls.
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
 19
 20pub fn capture_example(
 21    project: Entity<Project>,
 22    buffer: Entity<Buffer>,
 23    cursor_anchor: language::Anchor,
 24    mut events: Vec<StoredEvent>,
 25    related_files: Vec<zeta_prompt::RelatedFile>,
 26    populate_expected_patch: bool,
 27    cx: &mut App,
 28) -> Option<Task<Result<ExampleSpec>>> {
 29    let snapshot = buffer.read(cx).snapshot();
 30    let file = snapshot.file()?;
 31    let worktree_id = file.worktree_id(cx);
 32    let repository = project.read(cx).active_repository(cx)?;
 33    let repository_snapshot = repository.read(cx).snapshot();
 34    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 35    let root_name = worktree.read(cx).root_name_str().to_owned();
 36    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 37    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 38        return None;
 39    }
 40
 41    let repository_url = repository_snapshot
 42        .remote_origin_url
 43        .clone()
 44        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 45    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 46
 47    let git_store = project.read(cx).git_store().clone();
 48
 49    Some(cx.spawn(async move |mut cx| {
 50        let snapshots_by_path =
 51            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 52
 53        events.retain(|stored_event| {
 54            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 55            let relative_path = strip_root_name(path, &root_name);
 56            snapshots_by_path.contains_key(relative_path)
 57        });
 58
 59        let line_comment_prefix = snapshot
 60            .language()
 61            .and_then(|lang| lang.config().line_comments.first())
 62            .map(|s| s.to_string())
 63            .unwrap_or_default();
 64
 65        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
 66        let cursor_point = cursor_anchor.to_point(&snapshot);
 67        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
 68            Some(snapshot.text())
 69        } else {
 70            None
 71        };
 72
 73        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 74            .background_executor()
 75            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 76            .await;
 77        let uncommitted_diff = cx
 78            .background_executor()
 79            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 80            .await;
 81
 82        let mut edit_history = String::new();
 83        for stored_event in &events {
 84            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 85            if !edit_history.ends_with('\n') {
 86                edit_history.push('\n');
 87            }
 88        }
 89
 90        // Initialize an empty patch with context lines, to make it easy
 91        // to write the expected patch by hand.
 92        let mut expected_patches = Vec::new();
 93        let mut rejected_patch = None;
 94        if populate_expected_patch {
 95            let mut empty_patch = String::new();
 96            let start_row = cursor_excerpt_range.start.row + 1;
 97            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 98            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 99            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
100            writeln!(
101                &mut empty_patch,
102                "@@ -{},{} +{},{} @@",
103                start_row, row_count, start_row, row_count,
104            )
105            .ok();
106            for line in cursor_excerpt.lines() {
107                writeln!(&mut empty_patch, " {}", line).ok();
108            }
109
110            expected_patches.push(empty_patch.clone());
111            rejected_patch = Some(empty_patch);
112        }
113
114        let prompt_input = cursor_file_content.map(|content| {
115            let captured_events: Vec<CapturedEvent> = events
116                .iter()
117                .map(|stored_event| {
118                    let zeta_prompt::Event::BufferChange {
119                        path,
120                        old_path,
121                        diff,
122                        predicted,
123                        in_open_source_repo,
124                    } = stored_event.event.as_ref();
125                    CapturedEvent {
126                        path: strip_root_name(path, &root_name).into(),
127                        old_path: strip_root_name(old_path, &root_name).into(),
128                        diff: diff.clone(),
129                        predicted: *predicted,
130                        in_open_source_repo: *in_open_source_repo,
131                    }
132                })
133                .collect();
134
135            let captured_related_files: Vec<CapturedRelatedFile> = related_files
136                .iter()
137                .map(|rf| CapturedRelatedFile {
138                    path: strip_root_name(&rf.path, &root_name).into(),
139                    max_row: rf.max_row,
140                    excerpts: rf
141                        .excerpts
142                        .iter()
143                        .map(|e| CapturedRelatedExcerpt {
144                            row_range: e.row_range.clone(),
145                            text: e.text.to_string(),
146                        })
147                        .collect(),
148                })
149                .collect();
150
151            CapturedPromptInput {
152                cursor_file_content: content,
153                cursor_offset: full_cursor_offset,
154                cursor_row: cursor_point.row,
155                cursor_column: cursor_point.column,
156                excerpt_start_row: Some(0),
157                events: captured_events,
158                related_files: captured_related_files,
159            }
160        });
161
162        let mut spec = ExampleSpec {
163            name: generate_timestamp_name(),
164            repository_url,
165            revision,
166            tags: Vec::new(),
167            reasoning: None,
168            uncommitted_diff,
169            cursor_path,
170            cursor_position: String::new(),
171            edit_history,
172            expected_patches,
173            rejected_patch,
174            captured_prompt_input: prompt_input,
175            telemetry: None,
176            human_feedback: Vec::new(),
177            rating: None,
178        };
179        spec.set_cursor_excerpt(
180            &cursor_excerpt,
181            cursor_offset_in_excerpt,
182            &line_comment_prefix,
183        );
184        Ok(spec)
185    }))
186}
187
/// Returns `path` with a leading `root_name` prefix removed.
///
/// The prefix is matched on whole path components; when `path` does not start with
/// `root_name`, it is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
191
192fn write_event_with_relative_paths(
193    output: &mut String,
194    event: &zeta_prompt::Event,
195    root_name: &str,
196) {
197    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
198        for component in strip_root_name(path, root_name).components() {
199            output.push('/');
200            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
201        }
202    }
203
204    let zeta_prompt::Event::BufferChange {
205        path,
206        old_path,
207        diff,
208        ..
209    } = event;
210
211    output.push_str("--- a");
212    write_relative_path(output, old_path.as_ref(), root_name);
213    output.push_str("\n+++ b");
214    write_relative_path(output, path.as_ref(), root_name);
215    output.push('\n');
216    output.push_str(diff);
217}
218
219fn compute_cursor_excerpt(
220    snapshot: &language::BufferSnapshot,
221    cursor_anchor: language::Anchor,
222) -> (String, usize, Range<Point>) {
223    use text::ToOffset as _;
224
225    let cursor_point = cursor_anchor.to_point(snapshot);
226    let (_editable_range, context_range) =
227        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
228    let context_start_offset = context_range.start.to_offset(snapshot);
229    let cursor_offset = cursor_anchor.to_offset(snapshot);
230    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
231    let excerpt = snapshot
232        .text_for_range(context_range.clone())
233        .collect::<String>();
234    (excerpt, cursor_offset_in_excerpt, context_range)
235}
236
237async fn collect_snapshots(
238    project: &Entity<Project>,
239    git_store: &Entity<project::git_store::GitStore>,
240    worktree_id: WorktreeId,
241    events: &[StoredEvent],
242    cx: &mut gpui::AsyncApp,
243) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
244    let mut snapshots_by_path = HashMap::default();
245    for stored_event in events {
246        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
247        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
248            let project_path = project
249                .find_project_path(path, cx)
250                .filter(|path| path.worktree_id == worktree_id)?;
251            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
252            Some((project_path, relative_path))
253        }) {
254            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
255                let buffer = project
256                    .update(cx, |project, cx| {
257                        project.open_buffer(project_path.clone(), cx)
258                    })
259                    .await?;
260                let diff = git_store
261                    .update(cx, |git_store, cx| {
262                        git_store.open_uncommitted_diff(buffer.clone(), cx)
263                    })
264                    .await?;
265                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
266                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
267            }
268        }
269    }
270    Ok(snapshots_by_path)
271}
272
273fn compute_uncommitted_diff(
274    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
275) -> String {
276    let mut uncommitted_diff = String::new();
277    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
278        if let Some(head_text) = &diff_snapshot.base_text_string() {
279            let file_diff = language::unified_diff(head_text, &before_text.text());
280            if !file_diff.is_empty() {
281                let path_str = relative_path.to_string_lossy();
282                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
283                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
284                uncommitted_diff.push_str(&file_diff);
285                if !uncommitted_diff.ends_with('\n') {
286                    uncommitted_diff.push('\n');
287                }
288            }
289        }
290    }
291    uncommitted_diff
292}
293
294fn generate_timestamp_name() -> String {
295    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
296    match format {
297        Ok(format) => {
298            let now = time::OffsetDateTime::now_local()
299                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
300            now.format(&format)
301                .unwrap_or_else(|_| "unknown-time".to_string())
302        }
303        Err(_) => "unknown-time".to_string(),
304    }
305}
306
307pub(crate) fn should_send_testing_zeta2_request() -> bool {
308    rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
309}
310
311#[cfg(test)]
312mod tests {
313    use super::*;
314    use crate::EditPredictionStore;
315    use client::{Client, UserStore};
316    use clock::FakeSystemClock;
317    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
318    use indoc::indoc;
319    use language::{Anchor, Point};
320    use project::{FakeFs, Project};
321    use serde_json::json;
322    use settings::SettingsStore;
323    use std::path::Path;
324
    /// End-to-end check of `capture_example` against a fake repository.
    ///
    /// The file has three layers of content: what is committed at HEAD, what is on disk
    /// (HEAD plus two comment lines), and what is in the buffer (disk plus two more
    /// comment lines added by in-memory edits). An edit to a file outside the project
    /// worktree is also recorded, and the expected spec shows it is not part of the
    /// captured edit history.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of src/main.rs at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents of src/main.rs on disk: HEAD plus "comment 1" and "comment 2".
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // Give the fake repository a HEAD (file contents and sha) and an origin remote;
        // these feed the expected `revision` and `repository_url` below.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the edit prediction store; the store's edit history is
        // read back further down and passed to `capture_example`.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Make two in-memory edits on top of the on-disk contents.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer and with the
        // expected/rejected patch templates populated.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The captured name is generated from the current time, so pin it before comparing.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD -> contents before the in-memory edits: only comments 1 and 2.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the in-worktree edits (comments 3 and 4); the edit to
                // /external/external.rs does not appear.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // Context-only patch template covering the whole cursor excerpt.
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                // Same template as the expected patch.
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                // Compared against itself here, so it never fails this equality check;
                // its fields are asserted individually below.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(prompt_input.events.len() > 0, "should have captured events");
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }
626
627    fn init_test(cx: &mut TestAppContext) {
628        cx.update(|cx| {
629            let settings_store = SettingsStore::test(cx);
630            cx.set_global(settings_store);
631            zlog::init_test();
632            let http_client = FakeHttpClient::with_404_response();
633            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
634            language_model::init(client.clone(), cx);
635            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
636            EditPredictionStore::global(&client, &user_store, cx);
637        })
638    }
639}