capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
  4    example_spec::{
  5        CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
  6        ExampleSpec, MAX_CURSOR_FILE_SIZE,
  7    },
  8};
  9use anyhow::Result;
 10use buffer_diff::BufferDiffSnapshot;
 11use collections::HashMap;
 12use feature_flags::FeatureFlagAppExt as _;
 13use gpui::{App, Entity, Task};
 14use language::{Buffer, ToPoint as _};
 15use project::{Project, WorktreeId};
 16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
 18
/// Fallback example-capture rate, expressed per 10,000 predictions, used for
/// non-staff users when the settings do not specify a rate.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Fallback example-capture rate, expressed per 10,000 predictions, used for
/// staff users when the settings do not specify a rate.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
/// Rate, expressed per 10,000 predictions, at which a Zeta2 testing request is
/// sent (see `should_send_testing_zeta2_request`).
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
 22
 23pub fn capture_example(
 24    project: Entity<Project>,
 25    buffer: Entity<Buffer>,
 26    cursor_anchor: language::Anchor,
 27    mut events: Vec<StoredEvent>,
 28    related_files: Vec<zeta_prompt::RelatedFile>,
 29    populate_expected_patch: bool,
 30    cx: &mut App,
 31) -> Option<Task<Result<ExampleSpec>>> {
 32    let snapshot = buffer.read(cx).snapshot();
 33    let file = snapshot.file()?;
 34    let worktree_id = file.worktree_id(cx);
 35    let repository = project.read(cx).active_repository(cx)?;
 36    let repository_snapshot = repository.read(cx).snapshot();
 37    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 38    let root_name = worktree.read(cx).root_name_str().to_owned();
 39    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 40    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 41        return None;
 42    }
 43
 44    let repository_url = repository_snapshot
 45        .remote_origin_url
 46        .clone()
 47        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 48    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 49
 50    let git_store = project.read(cx).git_store().clone();
 51
 52    Some(cx.spawn(async move |mut cx| {
 53        let snapshots_by_path =
 54            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 55
 56        events.retain(|stored_event| {
 57            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 58            let relative_path = strip_root_name(path, &root_name);
 59            snapshots_by_path.contains_key(relative_path)
 60        });
 61
 62        let line_comment_prefix = snapshot
 63            .language()
 64            .and_then(|lang| lang.config().line_comments.first())
 65            .map(|s| s.to_string())
 66            .unwrap_or_default();
 67
 68        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
 69        let cursor_point = cursor_anchor.to_point(&snapshot);
 70        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
 71            Some(snapshot.text())
 72        } else {
 73            None
 74        };
 75
 76        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 77            .background_executor()
 78            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 79            .await;
 80        let uncommitted_diff = cx
 81            .background_executor()
 82            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 83            .await;
 84
 85        let mut edit_history = String::new();
 86        for stored_event in &events {
 87            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 88            if !edit_history.ends_with('\n') {
 89                edit_history.push('\n');
 90            }
 91        }
 92
 93        // Initialize an empty patch with context lines, to make it easy
 94        // to write the expected patch by hand.
 95        let mut expected_patches = Vec::new();
 96        let mut rejected_patch = None;
 97        if populate_expected_patch {
 98            let mut empty_patch = String::new();
 99            let start_row = cursor_excerpt_range.start.row + 1;
100            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
101            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
102            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
103            writeln!(
104                &mut empty_patch,
105                "@@ -{},{} +{},{} @@",
106                start_row, row_count, start_row, row_count,
107            )
108            .ok();
109            for line in cursor_excerpt.lines() {
110                writeln!(&mut empty_patch, " {}", line).ok();
111            }
112
113            expected_patches.push(empty_patch.clone());
114            rejected_patch = Some(empty_patch);
115        }
116
117        let prompt_input = cursor_file_content.map(|content| {
118            let captured_events: Vec<CapturedEvent> = events
119                .iter()
120                .map(|stored_event| {
121                    let zeta_prompt::Event::BufferChange {
122                        path,
123                        old_path,
124                        diff,
125                        predicted,
126                        in_open_source_repo,
127                    } = stored_event.event.as_ref();
128                    CapturedEvent {
129                        path: strip_root_name(path, &root_name).into(),
130                        old_path: strip_root_name(old_path, &root_name).into(),
131                        diff: diff.clone(),
132                        predicted: *predicted,
133                        in_open_source_repo: *in_open_source_repo,
134                    }
135                })
136                .collect();
137
138            let captured_related_files: Vec<CapturedRelatedFile> = related_files
139                .iter()
140                .map(|rf| CapturedRelatedFile {
141                    path: strip_root_name(&rf.path, &root_name).into(),
142                    max_row: rf.max_row,
143                    excerpts: rf
144                        .excerpts
145                        .iter()
146                        .map(|e| CapturedRelatedExcerpt {
147                            row_range: e.row_range.clone(),
148                            text: e.text.to_string(),
149                        })
150                        .collect(),
151                })
152                .collect();
153
154            CapturedPromptInput {
155                cursor_file_content: content,
156                cursor_offset: full_cursor_offset,
157                cursor_row: cursor_point.row,
158                cursor_column: cursor_point.column,
159                excerpt_start_row: Some(0),
160                events: captured_events,
161                related_files: captured_related_files,
162            }
163        });
164
165        let mut spec = ExampleSpec {
166            name: generate_timestamp_name(),
167            repository_url,
168            revision,
169            tags: Vec::new(),
170            reasoning: None,
171            uncommitted_diff,
172            cursor_path,
173            cursor_position: String::new(),
174            edit_history,
175            expected_patches,
176            rejected_patch,
177            captured_prompt_input: prompt_input,
178            telemetry: None,
179            human_feedback: Vec::new(),
180            rating: None,
181        };
182        spec.set_cursor_excerpt(
183            &cursor_excerpt,
184            cursor_offset_in_excerpt,
185            &line_comment_prefix,
186        );
187        Ok(spec)
188    }))
189}
190
/// Returns `path` with the leading `root_name` prefix removed.
///
/// When `path` does not start with `root_name` (compared by whole path
/// components), `path` is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
194
195fn write_event_with_relative_paths(
196    output: &mut String,
197    event: &zeta_prompt::Event,
198    root_name: &str,
199) {
200    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
201        for component in strip_root_name(path, root_name).components() {
202            output.push('/');
203            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
204        }
205    }
206
207    let zeta_prompt::Event::BufferChange {
208        path,
209        old_path,
210        diff,
211        ..
212    } = event;
213
214    output.push_str("--- a");
215    write_relative_path(output, old_path.as_ref(), root_name);
216    output.push_str("\n+++ b");
217    write_relative_path(output, path.as_ref(), root_name);
218    output.push('\n');
219    output.push_str(diff);
220}
221
222fn compute_cursor_excerpt(
223    snapshot: &language::BufferSnapshot,
224    cursor_anchor: language::Anchor,
225) -> (String, usize, Range<Point>) {
226    use text::ToOffset as _;
227
228    let cursor_point = cursor_anchor.to_point(snapshot);
229    let (_editable_range, context_range) =
230        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
231    let context_start_offset = context_range.start.to_offset(snapshot);
232    let cursor_offset = cursor_anchor.to_offset(snapshot);
233    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
234    let excerpt = snapshot
235        .text_for_range(context_range.clone())
236        .collect::<String>();
237    (excerpt, cursor_offset_in_excerpt, context_range)
238}
239
240async fn collect_snapshots(
241    project: &Entity<Project>,
242    git_store: &Entity<project::git_store::GitStore>,
243    worktree_id: WorktreeId,
244    events: &[StoredEvent],
245    cx: &mut gpui::AsyncApp,
246) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
247    let mut snapshots_by_path = HashMap::default();
248    for stored_event in events {
249        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
250        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
251            let project_path = project
252                .find_project_path(path, cx)
253                .filter(|path| path.worktree_id == worktree_id)?;
254            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
255            Some((project_path, relative_path))
256        }) {
257            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
258                let buffer = project
259                    .update(cx, |project, cx| {
260                        project.open_buffer(project_path.clone(), cx)
261                    })
262                    .await?;
263                let diff = git_store
264                    .update(cx, |git_store, cx| {
265                        git_store.open_uncommitted_diff(buffer.clone(), cx)
266                    })
267                    .await?;
268                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
269                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
270            }
271        }
272    }
273    Ok(snapshots_by_path)
274}
275
276fn compute_uncommitted_diff(
277    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
278) -> String {
279    let mut uncommitted_diff = String::new();
280    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
281        if let Some(head_text) = &diff_snapshot.base_text_string() {
282            let file_diff = language::unified_diff(head_text, &before_text.text());
283            if !file_diff.is_empty() {
284                let path_str = relative_path.to_string_lossy();
285                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
286                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
287                uncommitted_diff.push_str(&file_diff);
288                if !uncommitted_diff.ends_with('\n') {
289                    uncommitted_diff.push('\n');
290                }
291            }
292        }
293    }
294    uncommitted_diff
295}
296
297fn generate_timestamp_name() -> String {
298    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
299    match format {
300        Ok(format) => {
301            let now = time::OffsetDateTime::now_local()
302                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
303            now.format(&format)
304                .unwrap_or_else(|_| "unknown-time".to_string())
305        }
306        Err(_) => "unknown-time".to_string(),
307    }
308}
309
310pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
311    let default_rate = if cx.is_staff() {
312        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
313    } else {
314        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
315    };
316    let capture_rate = language::language_settings::all_language_settings(None, cx)
317        .edit_predictions
318        .example_capture_rate
319        .unwrap_or(default_rate);
320    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
321        && rand::random::<u16>() % 10_000 < capture_rate
322}
323
324pub(crate) fn should_send_testing_zeta2_request() -> bool {
325    rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`: a file with committed, on-disk
    /// and in-buffer changes is captured, and an edit to a file outside the
    /// project's worktree must not show up in the resulting spec.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at HEAD.
        let head_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents of `src/main.rs` on disk: HEAD plus two comments.
        let on_disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": on_disk_contents,
                }
            }),
        )
        .await;

        // A file that lives outside the main project.
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        let git_dir = Path::new("/project/.git");
        fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
        cx.run_until_parked();

        // Two in-buffer edits, applied bottom-up so the first insertion does
        // not shift the row of the second.
        buffer.update(cx, |buffer, cx| {
            let insertions = [
                (Point::new(6, 0), "    // comment 3\n"),
                (Point::new(4, 0), "    // comment 4\n"),
            ];
            for (position, text) in insertions {
                buffer.edit([(position..position, text)], None, cx);
            }

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit the external file (outside the main project's worktree).
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        store.update(cx, |store, cx| {
            store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let start = Point::new(0, 0);
            buffer.edit([(start..start, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // The external edit must have been recorded as the latest event.
        let events = store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        let last_event = events.last().unwrap().event.as_ref();
        assert!(
            matches!(
                last_event,
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let capture_task = cx.update(|cx| {
            capture_example(
                project.clone(),
                buffer.clone(),
                Anchor::MIN,
                events,
                Vec::new(),
                true,
                cx,
            )
            .unwrap()
        });
        let mut example = capture_task.await.unwrap();
        // The generated name is time-based; pin it so the comparison is stable.
        example.name = "test".to_string();

        // Context-only patch expected for both `expected_patches` and
        // `rejected_patch`.
        let context_only_patch = indoc! {"
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,16 +1,16 @@
             fn main() {
                 // comment 1
                 one();
                 two();
                 // comment 4
                 three();
                 four();
                 // comment 3
                 five();
                 six();
                 seven();
                 eight();
                 // comment 2
                 nine();
             }
        "}
        .to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![context_only_patch.clone()],
                rejected_patch: Some(context_only_patch),
                // Compared field-by-field below instead.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
                human_feedback: Vec::new(),
                rating: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(
            !prompt_input.events.is_empty(),
            "should have captured events"
        );
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs the globals the test relies on: a test settings store, test
    /// logging, a client backed by an HTTP client that always answers 404, and
    /// the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings = SettingsStore::test(cx);
            cx.set_global(settings);
            zlog::init_test();

            let clock = Arc::new(FakeSystemClock::new());
            let client = Client::new(clock, FakeHttpClient::with_404_response(), cx);
            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}