capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
  4    example_spec::{
  5        CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
  6        ExampleSpec, MAX_CURSOR_FILE_SIZE,
  7    },
  8};
  9use anyhow::Result;
 10use buffer_diff::BufferDiffSnapshot;
 11use collections::HashMap;
 12use feature_flags::FeatureFlagAppExt as _;
 13use gpui::{App, Entity, Task};
 14use language::{Buffer, ToPoint as _};
 15use project::{Project, WorktreeId};
 16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
 18
 /// Number of edit predictions, out of every 10,000, that are captured as examples
/// (10 / 10,000 = 0.1%) when the `example_capture_rate` setting is not set.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;

/// Staff counterpart of [`DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS`]
/// (100 / 10,000 = 1%), likewise only used when no explicit rate is configured.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
 21
 22pub fn capture_example(
 23    project: Entity<Project>,
 24    buffer: Entity<Buffer>,
 25    cursor_anchor: language::Anchor,
 26    mut events: Vec<StoredEvent>,
 27    related_files: Vec<zeta_prompt::RelatedFile>,
 28    populate_expected_patch: bool,
 29    cx: &mut App,
 30) -> Option<Task<Result<ExampleSpec>>> {
 31    let snapshot = buffer.read(cx).snapshot();
 32    let file = snapshot.file()?;
 33    let worktree_id = file.worktree_id(cx);
 34    let repository = project.read(cx).active_repository(cx)?;
 35    let repository_snapshot = repository.read(cx).snapshot();
 36    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 37    let root_name = worktree.read(cx).root_name_str().to_owned();
 38    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 39    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 40        return None;
 41    }
 42
 43    let repository_url = repository_snapshot
 44        .remote_origin_url
 45        .clone()
 46        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 47    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 48
 49    let git_store = project.read(cx).git_store().clone();
 50
 51    Some(cx.spawn(async move |mut cx| {
 52        let snapshots_by_path =
 53            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 54
 55        events.retain(|stored_event| {
 56            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 57            let relative_path = strip_root_name(path, &root_name);
 58            snapshots_by_path.contains_key(relative_path)
 59        });
 60
 61        let line_comment_prefix = snapshot
 62            .language()
 63            .and_then(|lang| lang.config().line_comments.first())
 64            .map(|s| s.to_string())
 65            .unwrap_or_default();
 66
 67        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
 68        let cursor_point = cursor_anchor.to_point(&snapshot);
 69        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
 70            Some(snapshot.text())
 71        } else {
 72            None
 73        };
 74
 75        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 76            .background_executor()
 77            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 78            .await;
 79        let uncommitted_diff = cx
 80            .background_executor()
 81            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 82            .await;
 83
 84        let mut edit_history = String::new();
 85        for stored_event in &events {
 86            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 87            if !edit_history.ends_with('\n') {
 88                edit_history.push('\n');
 89            }
 90        }
 91
 92        // Initialize an empty patch with context lines, to make it easy
 93        // to write the expected patch by hand.
 94        let mut expected_patches = Vec::new();
 95        let mut rejected_patch = None;
 96        if populate_expected_patch {
 97            let mut empty_patch = String::new();
 98            let start_row = cursor_excerpt_range.start.row + 1;
 99            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
100            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
101            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
102            writeln!(
103                &mut empty_patch,
104                "@@ -{},{} +{},{} @@",
105                start_row, row_count, start_row, row_count,
106            )
107            .ok();
108            for line in cursor_excerpt.lines() {
109                writeln!(&mut empty_patch, " {}", line).ok();
110            }
111
112            expected_patches.push(empty_patch.clone());
113            rejected_patch = Some(empty_patch);
114        }
115
116        let prompt_input = cursor_file_content.map(|content| {
117            let captured_events: Vec<CapturedEvent> = events
118                .iter()
119                .map(|stored_event| {
120                    let zeta_prompt::Event::BufferChange {
121                        path,
122                        old_path,
123                        diff,
124                        predicted,
125                        in_open_source_repo,
126                    } = stored_event.event.as_ref();
127                    CapturedEvent {
128                        path: strip_root_name(path, &root_name).into(),
129                        old_path: strip_root_name(old_path, &root_name).into(),
130                        diff: diff.clone(),
131                        predicted: *predicted,
132                        in_open_source_repo: *in_open_source_repo,
133                    }
134                })
135                .collect();
136
137            let captured_related_files: Vec<CapturedRelatedFile> = related_files
138                .iter()
139                .map(|rf| CapturedRelatedFile {
140                    path: strip_root_name(&rf.path, &root_name).into(),
141                    max_row: rf.max_row,
142                    excerpts: rf
143                        .excerpts
144                        .iter()
145                        .map(|e| CapturedRelatedExcerpt {
146                            row_range: e.row_range.clone(),
147                            text: e.text.to_string(),
148                        })
149                        .collect(),
150                })
151                .collect();
152
153            CapturedPromptInput {
154                cursor_file_content: content,
155                cursor_offset: full_cursor_offset,
156                cursor_row: cursor_point.row,
157                cursor_column: cursor_point.column,
158                events: captured_events,
159                related_files: captured_related_files,
160            }
161        });
162
163        let mut spec = ExampleSpec {
164            name: generate_timestamp_name(),
165            repository_url,
166            revision,
167            tags: Vec::new(),
168            reasoning: None,
169            uncommitted_diff,
170            cursor_path,
171            cursor_position: String::new(),
172            edit_history,
173            expected_patches,
174            rejected_patch,
175            captured_prompt_input: prompt_input,
176        };
177        spec.set_cursor_excerpt(
178            &cursor_excerpt,
179            cursor_offset_in_excerpt,
180            &line_comment_prefix,
181        );
182        Ok(spec)
183    }))
184}
185
186fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
187    path.strip_prefix(root_name).unwrap_or(path)
188}
189
190fn write_event_with_relative_paths(
191    output: &mut String,
192    event: &zeta_prompt::Event,
193    root_name: &str,
194) {
195    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
196        for component in strip_root_name(path, root_name).components() {
197            output.push('/');
198            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
199        }
200    }
201
202    let zeta_prompt::Event::BufferChange {
203        path,
204        old_path,
205        diff,
206        ..
207    } = event;
208
209    output.push_str("--- a");
210    write_relative_path(output, old_path.as_ref(), root_name);
211    output.push_str("\n+++ b");
212    write_relative_path(output, path.as_ref(), root_name);
213    output.push('\n');
214    output.push_str(diff);
215}
216
217fn compute_cursor_excerpt(
218    snapshot: &language::BufferSnapshot,
219    cursor_anchor: language::Anchor,
220) -> (String, usize, Range<Point>) {
221    use text::ToOffset as _;
222
223    let cursor_point = cursor_anchor.to_point(snapshot);
224    let (_editable_range, context_range) =
225        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
226    let context_start_offset = context_range.start.to_offset(snapshot);
227    let cursor_offset = cursor_anchor.to_offset(snapshot);
228    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
229    let excerpt = snapshot
230        .text_for_range(context_range.clone())
231        .collect::<String>();
232    (excerpt, cursor_offset_in_excerpt, context_range)
233}
234
235async fn collect_snapshots(
236    project: &Entity<Project>,
237    git_store: &Entity<project::git_store::GitStore>,
238    worktree_id: WorktreeId,
239    events: &[StoredEvent],
240    cx: &mut gpui::AsyncApp,
241) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
242    let mut snapshots_by_path = HashMap::default();
243    for stored_event in events {
244        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
245        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
246            let project_path = project
247                .find_project_path(path, cx)
248                .filter(|path| path.worktree_id == worktree_id)?;
249            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
250            Some((project_path, relative_path))
251        }) {
252            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
253                let buffer = project
254                    .update(cx, |project, cx| {
255                        project.open_buffer(project_path.clone(), cx)
256                    })
257                    .await?;
258                let diff = git_store
259                    .update(cx, |git_store, cx| {
260                        git_store.open_uncommitted_diff(buffer.clone(), cx)
261                    })
262                    .await?;
263                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
264                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
265            }
266        }
267    }
268    Ok(snapshots_by_path)
269}
270
271fn compute_uncommitted_diff(
272    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
273) -> String {
274    let mut uncommitted_diff = String::new();
275    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
276        if let Some(head_text) = &diff_snapshot.base_text_string() {
277            let file_diff = language::unified_diff(head_text, &before_text.text());
278            if !file_diff.is_empty() {
279                let path_str = relative_path.to_string_lossy();
280                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
281                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
282                uncommitted_diff.push_str(&file_diff);
283                if !uncommitted_diff.ends_with('\n') {
284                    uncommitted_diff.push('\n');
285                }
286            }
287        }
288    }
289    uncommitted_diff
290}
291
292fn generate_timestamp_name() -> String {
293    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
294    match format {
295        Ok(format) => {
296            let now = time::OffsetDateTime::now_local()
297                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
298            now.format(&format)
299                .unwrap_or_else(|_| "unknown-time".to_string())
300        }
301        Err(_) => "unknown-time".to_string(),
302    }
303}
304
305pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
306    let default_rate = if cx.is_staff() {
307        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
308    } else {
309        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
310    };
311    let capture_rate = language::language_settings::all_language_settings(None, cx)
312        .edit_predictions
313        .example_capture_rate
314        .unwrap_or(default_rate);
315    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
316        && rand::random::<u16>() % 10_000 < capture_rate
317}
318
319#[cfg(test)]
320mod tests {
321    use super::*;
322    use crate::EditPredictionStore;
323    use client::{Client, UserStore};
324    use clock::FakeSystemClock;
325    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
326    use indoc::indoc;
327    use language::{Anchor, Point};
328    use project::{FakeFs, Project};
329    use serde_json::json;
330    use settings::SettingsStore;
331    use std::path::Path;
332
333    #[gpui::test]
334    async fn test_capture_example(cx: &mut TestAppContext) {
335        init_test(cx);
336        let fs = FakeFs::new(cx.executor());
337
338        let committed_contents = indoc! {"
339            fn main() {
340                one();
341                two();
342                three();
343                four();
344                five();
345                six();
346                seven();
347                eight();
348                nine();
349            }
350        "};
351
352        let disk_contents = indoc! {"
353            fn main() {
354                // comment 1
355                one();
356                two();
357                three();
358                four();
359                five();
360                six();
361                seven();
362                eight();
363                // comment 2
364                nine();
365            }
366        "};
367
368        fs.insert_tree(
369            "/project",
370            json!({
371                ".git": {},
372                "src": {
373                    "main.rs": disk_contents,
374                }
375            }),
376        )
377        .await;
378
379        // Create an external file outside the main project
380        fs.insert_tree(
381            "/external",
382            json!({
383                "external.rs": "fn external() {}\n",
384            }),
385        )
386        .await;
387
388        fs.set_head_for_repo(
389            Path::new("/project/.git"),
390            &[("src/main.rs", committed_contents.to_string())],
391            "abc123def456",
392        );
393        fs.set_remote_for_repo(
394            Path::new("/project/.git"),
395            "origin",
396            "https://github.com/test/repo.git",
397        );
398
399        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
400
401        let buffer = project
402            .update(cx, |project, cx| {
403                project.open_local_buffer("/project/src/main.rs", cx)
404            })
405            .await
406            .unwrap();
407
408        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
409        ep_store.update(cx, |ep_store, cx| {
410            ep_store.register_buffer(&buffer, &project, cx)
411        });
412        cx.run_until_parked();
413
414        buffer.update(cx, |buffer, cx| {
415            let point = Point::new(6, 0);
416            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
417            let point = Point::new(4, 0);
418            buffer.edit([(point..point, "    // comment 4\n")], None, cx);
419
420            pretty_assertions::assert_eq!(
421                buffer.text(),
422                indoc! {"
423                    fn main() {
424                        // comment 1
425                        one();
426                        two();
427                        // comment 4
428                        three();
429                        four();
430                        // comment 3
431                        five();
432                        six();
433                        seven();
434                        eight();
435                        // comment 2
436                        nine();
437                    }
438                "}
439            );
440        });
441        cx.run_until_parked();
442
443        // Open and edit an external file (outside the main project's worktree)
444        let external_buffer = project
445            .update(cx, |project, cx| {
446                project.open_local_buffer("/external/external.rs", cx)
447            })
448            .await
449            .unwrap();
450        ep_store.update(cx, |ep_store, cx| {
451            ep_store.register_buffer(&external_buffer, &project, cx)
452        });
453        cx.run_until_parked();
454        external_buffer.update(cx, |buffer, cx| {
455            let point = Point::new(0, 0);
456            buffer.edit([(point..point, "// external edit\n")], None, cx);
457        });
458        cx.run_until_parked();
459
460        // Verify the external edit was recorded in events
461        let events = ep_store.update(cx, |store, cx| {
462            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
463        });
464        assert!(
465            matches!(
466                events
467                    .last()
468                    .unwrap()
469                    .event
470                    .as_ref(),
471                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
472            ),
473            "external file edit should be in events"
474        );
475
476        let mut example = cx
477            .update(|cx| {
478                capture_example(
479                    project.clone(),
480                    buffer.clone(),
481                    Anchor::MIN,
482                    events,
483                    Vec::new(),
484                    true,
485                    cx,
486                )
487                .unwrap()
488            })
489            .await
490            .unwrap();
491        example.name = "test".to_string();
492
493        pretty_assertions::assert_eq!(
494            example,
495            ExampleSpec {
496                name: "test".to_string(),
497                repository_url: "https://github.com/test/repo.git".to_string(),
498                revision: "abc123def456".to_string(),
499                tags: Vec::new(),
500                reasoning: None,
501                uncommitted_diff: indoc! {"
502                    --- a/src/main.rs
503                    +++ b/src/main.rs
504                    @@ -1,4 +1,5 @@
505                     fn main() {
506                    +    // comment 1
507                         one();
508                         two();
509                         three();
510                    @@ -7,5 +8,6 @@
511                         six();
512                         seven();
513                         eight();
514                    +    // comment 2
515                         nine();
516                     }
517                "}
518                .to_string(),
519                cursor_path: Path::new("src/main.rs").into(),
520                cursor_position: indoc! {"
521                    fn main() {
522                    ^[CURSOR_POSITION]
523                        // comment 1
524                        one();
525                        two();
526                        // comment 4
527                        three();
528                        four();
529                        // comment 3
530                        five();
531                        six();
532                        seven();
533                        eight();
534                        // comment 2
535                        nine();
536                    }
537                "}
538                .to_string(),
539                edit_history: indoc! {"
540                    --- a/src/main.rs
541                    +++ b/src/main.rs
542                    @@ -2,8 +2,10 @@
543                         // comment 1
544                         one();
545                         two();
546                    +    // comment 4
547                         three();
548                         four();
549                    +    // comment 3
550                         five();
551                         six();
552                         seven();
553                "}
554                .to_string(),
555                expected_patches: vec![
556                    indoc! {"
557                        --- a/src/main.rs
558                        +++ b/src/main.rs
559                        @@ -1,16 +1,16 @@
560                         fn main() {
561                             // comment 1
562                             one();
563                             two();
564                             // comment 4
565                             three();
566                             four();
567                             // comment 3
568                             five();
569                             six();
570                             seven();
571                             eight();
572                             // comment 2
573                             nine();
574                         }
575                    "}
576                    .to_string()
577                ],
578                rejected_patch: Some(
579                    indoc! {"
580                        --- a/src/main.rs
581                        +++ b/src/main.rs
582                        @@ -1,16 +1,16 @@
583                         fn main() {
584                             // comment 1
585                             one();
586                             two();
587                             // comment 4
588                             three();
589                             four();
590                             // comment 3
591                             five();
592                             six();
593                             seven();
594                             eight();
595                             // comment 2
596                             nine();
597                         }
598                    "}
599                    .to_string()
600                ),
601                captured_prompt_input: example.captured_prompt_input.clone(),
602            }
603        );
604
605        let prompt_input = example
606            .captured_prompt_input
607            .expect("should have captured prompt input");
608        assert!(
609            prompt_input.cursor_file_content.contains("fn main()"),
610            "cursor_file_content should contain file content"
611        );
612        assert_eq!(
613            prompt_input.cursor_offset, 0,
614            "cursor at Anchor::MIN should be offset 0"
615        );
616        assert_eq!(
617            prompt_input.cursor_row, 0,
618            "cursor at Anchor::MIN should be row 0"
619        );
620        assert_eq!(
621            prompt_input.cursor_column, 0,
622            "cursor at Anchor::MIN should be column 0"
623        );
624        assert!(prompt_input.events.len() > 0, "should have captured events");
625        assert_eq!(
626            prompt_input.related_files.len(),
627            0,
628            "should have no related files (none passed)"
629        );
630    }
631
632    fn init_test(cx: &mut TestAppContext) {
633        cx.update(|cx| {
634            let settings_store = SettingsStore::test(cx);
635            cx.set_global(settings_store);
636            zlog::init_test();
637            let http_client = FakeHttpClient::with_404_response();
638            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
639            language_model::init(client.clone(), cx);
640            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
641            EditPredictionStore::global(&client, &user_store, cx);
642        })
643    }
644}