capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
  4    example_spec::{
  5        CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
  6        ExampleSpec, MAX_CURSOR_FILE_SIZE,
  7    },
  8};
  9use anyhow::Result;
 10use buffer_diff::BufferDiffSnapshot;
 11use collections::HashMap;
 12use feature_flags::FeatureFlagAppExt as _;
 13use gpui::{App, Entity, Task};
 14use language::{Buffer, ToPoint as _};
 15use project::{Project, WorktreeId};
 16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
 18
 19pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
 20pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
 21
/// Captures the current editing session as an [`ExampleSpec`].
///
/// Returns `None`, without spawning any work, when the example cannot be tied to a git
/// checkout. That happens when any of the following holds:
/// - the buffer has no file;
/// - the project has no active repository;
/// - the buffer's worktree cannot be found;
/// - the worktree root is not the repository's working directory;
/// - the repository has neither an origin nor an upstream remote URL;
/// - the repository has no HEAD commit.
///
/// Otherwise returns a task that builds the spec:
/// - `events` are resolved against the cursor buffer's worktree. Events for files
///   anywhere else (other worktrees, files outside the project) are dropped.
/// - `uncommitted_diff` holds, for each file the remaining events touch, the diff from
///   that file's uncommitted-diff base text to the `old_snapshot` of the first event
///   for it.
/// - `edit_history` is the remaining events rendered as unified diffs with
///   worktree-relative paths.
/// - `cursor_position` is filled in by `ExampleSpec::set_cursor_excerpt` from an
///   excerpt around `cursor_anchor`.
/// - When `populate_expected_patch` is true, `expected_patches` and `rejected_patch` are
///   seeded with a patch that contains only that excerpt as context lines.
/// - `captured_prompt_input` records the raw inputs (including `related_files`) when the
///   cursor file is at most `MAX_CURSOR_FILE_SIZE` long, and is `None` otherwise.
///
/// The spec's `name` is a timestamp generated when the task runs.
///
/// # Errors
/// The task fails if opening a touched file's buffer or its uncommitted diff fails.
pub fn capture_example(
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    cursor_anchor: language::Anchor,
    mut events: Vec<StoredEvent>,
    related_files: Vec<zeta_prompt::RelatedFile>,
    populate_expected_patch: bool,
    cx: &mut App,
) -> Option<Task<Result<ExampleSpec>>> {
    let snapshot = buffer.read(cx).snapshot();
    let file = snapshot.file()?;
    let worktree_id = file.worktree_id(cx);
    let repository = project.read(cx).active_repository(cx)?;
    let repository_snapshot = repository.read(cx).snapshot();
    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
    let root_name = worktree.read(cx).root_name_str().to_owned();
    let cursor_path: Arc<Path> = file.path().as_std_path().into();
    // Only capture when the worktree root is exactly the repository's working
    // directory, so the worktree-relative paths written below are also
    // repository-relative.
    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
        return None;
    }

    // Prefer the origin remote URL, falling back to the upstream one.
    let repository_url = repository_snapshot
        .remote_origin_url
        .clone()
        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();

    let git_store = project.read(cx).git_store().clone();

    Some(cx.spawn(async move |mut cx| {
        // Keyed by worktree-relative path; only files inside `worktree_id` are present.
        let snapshots_by_path =
            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;

        // Drop events for files that were not collected above (e.g. files outside this
        // worktree). Event paths may be prefixed with the worktree root name, so strip
        // it before looking them up.
        events.retain(|stored_event| {
            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
            let relative_path = strip_root_name(path, &root_name);
            snapshots_by_path.contains_key(relative_path)
        });

        // Empty when the buffer has no language or the language has no line comments.
        let line_comment_prefix = snapshot
            .language()
            .and_then(|lang| lang.config().line_comments.first())
            .map(|s| s.to_string())
            .unwrap_or_default();

        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
        let cursor_point = cursor_anchor.to_point(&snapshot);
        // The full file text is only kept for files up to `MAX_CURSOR_FILE_SIZE`.
        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
            Some(snapshot.text())
        } else {
            None
        };

        // Excerpt extraction and diff rendering run on the background executor.
        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
            .background_executor()
            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
            .await;
        let uncommitted_diff = cx
            .background_executor()
            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
            .await;

        // Concatenate the remaining events, each section terminated by a newline.
        let mut edit_history = String::new();
        for stored_event in &events {
            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
            if !edit_history.ends_with('\n') {
                edit_history.push('\n');
            }
        }

        // Initialize an empty patch with context lines, to make it easy
        // to write the expected patch by hand.
        let mut expected_patches = Vec::new();
        let mut rejected_patch = None;
        if populate_expected_patch {
            let mut empty_patch = String::new();
            // Hunk headers use 1-based line numbers.
            let start_row = cursor_excerpt_range.start.row + 1;
            // NOTE(review): if the excerpt range ends at column 0 of its end row (e.g.
            // right after the file's trailing newline), this count is one more than the
            // number of lines `cursor_excerpt.lines()` writes below. The test in this
            // file expects `@@ -1,16 +1,16 @@` followed by 15 lines. Confirm whether
            // consumers of this patch tolerate the mismatch.
            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
            writeln!(
                &mut empty_patch,
                "@@ -{},{} +{},{} @@",
                start_row, row_count, start_row, row_count,
            )
            .ok();
            for line in cursor_excerpt.lines() {
                writeln!(&mut empty_patch, " {}", line).ok();
            }

            expected_patches.push(empty_patch.clone());
            rejected_patch = Some(empty_patch);
        }

        // Raw prompt inputs are only recorded when the full file text was kept.
        // Paths in events and related files are made worktree-relative.
        let prompt_input = cursor_file_content.map(|content| {
            let captured_events: Vec<CapturedEvent> = events
                .iter()
                .map(|stored_event| {
                    let zeta_prompt::Event::BufferChange {
                        path,
                        old_path,
                        diff,
                        predicted,
                        in_open_source_repo,
                    } = stored_event.event.as_ref();
                    CapturedEvent {
                        path: strip_root_name(path, &root_name).into(),
                        old_path: strip_root_name(old_path, &root_name).into(),
                        diff: diff.clone(),
                        predicted: *predicted,
                        in_open_source_repo: *in_open_source_repo,
                    }
                })
                .collect();

            let captured_related_files: Vec<CapturedRelatedFile> = related_files
                .iter()
                .map(|rf| CapturedRelatedFile {
                    path: strip_root_name(&rf.path, &root_name).into(),
                    max_row: rf.max_row,
                    excerpts: rf
                        .excerpts
                        .iter()
                        .map(|e| CapturedRelatedExcerpt {
                            row_range: e.row_range.clone(),
                            text: e.text.to_string(),
                        })
                        .collect(),
                })
                .collect();

            CapturedPromptInput {
                cursor_file_content: content,
                cursor_offset: full_cursor_offset,
                cursor_row: cursor_point.row,
                cursor_column: cursor_point.column,
                events: captured_events,
                related_files: captured_related_files,
            }
        });

        // `cursor_position` starts empty and is filled in by `set_cursor_excerpt` below.
        let mut spec = ExampleSpec {
            name: generate_timestamp_name(),
            repository_url,
            revision,
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff,
            cursor_path,
            cursor_position: String::new(),
            edit_history,
            expected_patches,
            rejected_patch,
            captured_prompt_input: prompt_input,
            telemetry: None,
        };
        spec.set_cursor_excerpt(
            &cursor_excerpt,
            cursor_offset_in_excerpt,
            &line_comment_prefix,
        );
        Ok(spec)
    }))
}
186
/// Returns `path` with a leading `root_name` component removed.
///
/// `path` is returned unchanged when its first component is not `root_name`
/// (including absolute paths and paths whose first component merely starts with it).
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
190
191fn write_event_with_relative_paths(
192    output: &mut String,
193    event: &zeta_prompt::Event,
194    root_name: &str,
195) {
196    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
197        for component in strip_root_name(path, root_name).components() {
198            output.push('/');
199            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
200        }
201    }
202
203    let zeta_prompt::Event::BufferChange {
204        path,
205        old_path,
206        diff,
207        ..
208    } = event;
209
210    output.push_str("--- a");
211    write_relative_path(output, old_path.as_ref(), root_name);
212    output.push_str("\n+++ b");
213    write_relative_path(output, path.as_ref(), root_name);
214    output.push('\n');
215    output.push_str(diff);
216}
217
218fn compute_cursor_excerpt(
219    snapshot: &language::BufferSnapshot,
220    cursor_anchor: language::Anchor,
221) -> (String, usize, Range<Point>) {
222    use text::ToOffset as _;
223
224    let cursor_point = cursor_anchor.to_point(snapshot);
225    let (_editable_range, context_range) =
226        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
227    let context_start_offset = context_range.start.to_offset(snapshot);
228    let cursor_offset = cursor_anchor.to_offset(snapshot);
229    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
230    let excerpt = snapshot
231        .text_for_range(context_range.clone())
232        .collect::<String>();
233    (excerpt, cursor_offset_in_excerpt, context_range)
234}
235
236async fn collect_snapshots(
237    project: &Entity<Project>,
238    git_store: &Entity<project::git_store::GitStore>,
239    worktree_id: WorktreeId,
240    events: &[StoredEvent],
241    cx: &mut gpui::AsyncApp,
242) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
243    let mut snapshots_by_path = HashMap::default();
244    for stored_event in events {
245        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
246        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
247            let project_path = project
248                .find_project_path(path, cx)
249                .filter(|path| path.worktree_id == worktree_id)?;
250            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
251            Some((project_path, relative_path))
252        }) {
253            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
254                let buffer = project
255                    .update(cx, |project, cx| {
256                        project.open_buffer(project_path.clone(), cx)
257                    })
258                    .await?;
259                let diff = git_store
260                    .update(cx, |git_store, cx| {
261                        git_store.open_uncommitted_diff(buffer.clone(), cx)
262                    })
263                    .await?;
264                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
265                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
266            }
267        }
268    }
269    Ok(snapshots_by_path)
270}
271
272fn compute_uncommitted_diff(
273    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
274) -> String {
275    let mut uncommitted_diff = String::new();
276    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
277        if let Some(head_text) = &diff_snapshot.base_text_string() {
278            let file_diff = language::unified_diff(head_text, &before_text.text());
279            if !file_diff.is_empty() {
280                let path_str = relative_path.to_string_lossy();
281                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
282                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
283                uncommitted_diff.push_str(&file_diff);
284                if !uncommitted_diff.ends_with('\n') {
285                    uncommitted_diff.push('\n');
286                }
287            }
288        }
289    }
290    uncommitted_diff
291}
292
293fn generate_timestamp_name() -> String {
294    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
295    match format {
296        Ok(format) => {
297            let now = time::OffsetDateTime::now_local()
298                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
299            now.format(&format)
300                .unwrap_or_else(|_| "unknown-time".to_string())
301        }
302        Err(_) => "unknown-time".to_string(),
303    }
304}
305
306pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
307    let default_rate = if cx.is_staff() {
308        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
309    } else {
310        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
311    };
312    let capture_rate = language::language_settings::all_language_settings(None, cx)
313        .edit_predictions
314        .example_capture_rate
315        .unwrap_or(default_rate);
316    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
317        && rand::random::<u16>() % 10_000 < capture_rate
318}
319
// Tests for `capture_example`, run against a fake filesystem and fake git repository.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`. It verifies that:
    /// - changes already on disk relative to HEAD land in `uncommitted_diff`;
    /// - edits made after the buffer is registered land in `edit_history`;
    /// - edits to a file outside the worktree are recorded as events but do not appear in
    ///   `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at the HEAD commit.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // On-disk contents: two comment lines not present in HEAD.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // Make `committed_contents` the HEAD version of `src/main.rs` and give the
        // repository an origin remote, both of which `capture_example` requires.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the edit prediction store before editing it.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // These edits are expected in `edit_history`, not in `uncommitted_diff`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the start of the buffer and request the
        // expected-patch template.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is a timestamp; pin it so the spec can be compared exactly.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                // Compared field-by-field below instead.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(prompt_input.events.len() > 0, "should have captured events");
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs the test settings store, test logging, a client backed by a fake HTTP
    /// client that answers 404, language models, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}