capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
  4    example_spec::{
  5        CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
  6        ExampleSpec, MAX_CURSOR_FILE_SIZE,
  7    },
  8};
  9use anyhow::Result;
 10use buffer_diff::BufferDiffSnapshot;
 11use collections::HashMap;
 12use feature_flags::FeatureFlagAppExt as _;
 13use gpui::{App, Entity, Task};
 14use language::{Buffer, ToPoint as _};
 15use project::{Project, WorktreeId};
 16use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 17use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
 18
/// Fallback example-capture rate, expressed per 10,000 predictions, used when
/// the settings do not specify one and the user is not staff.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Fallback example-capture rate, expressed per 10,000 predictions, used when
/// the settings do not specify one and the user is staff.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
/// Threshold, out of 10,000, below which a random draw makes
/// `should_send_testing_zeta2_request` return `true`.
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
 22
 23pub fn capture_example(
 24    project: Entity<Project>,
 25    buffer: Entity<Buffer>,
 26    cursor_anchor: language::Anchor,
 27    mut events: Vec<StoredEvent>,
 28    related_files: Vec<zeta_prompt::RelatedFile>,
 29    populate_expected_patch: bool,
 30    cx: &mut App,
 31) -> Option<Task<Result<ExampleSpec>>> {
 32    let snapshot = buffer.read(cx).snapshot();
 33    let file = snapshot.file()?;
 34    let worktree_id = file.worktree_id(cx);
 35    let repository = project.read(cx).active_repository(cx)?;
 36    let repository_snapshot = repository.read(cx).snapshot();
 37    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 38    let root_name = worktree.read(cx).root_name_str().to_owned();
 39    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 40    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 41        return None;
 42    }
 43
 44    let repository_url = repository_snapshot
 45        .remote_origin_url
 46        .clone()
 47        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 48    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 49
 50    let git_store = project.read(cx).git_store().clone();
 51
 52    Some(cx.spawn(async move |mut cx| {
 53        let snapshots_by_path =
 54            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 55
 56        events.retain(|stored_event| {
 57            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 58            let relative_path = strip_root_name(path, &root_name);
 59            snapshots_by_path.contains_key(relative_path)
 60        });
 61
 62        let line_comment_prefix = snapshot
 63            .language()
 64            .and_then(|lang| lang.config().line_comments.first())
 65            .map(|s| s.to_string())
 66            .unwrap_or_default();
 67
 68        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
 69        let cursor_point = cursor_anchor.to_point(&snapshot);
 70        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
 71            Some(snapshot.text())
 72        } else {
 73            None
 74        };
 75
 76        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
 77            .background_executor()
 78            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 79            .await;
 80        let uncommitted_diff = cx
 81            .background_executor()
 82            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 83            .await;
 84
 85        let mut edit_history = String::new();
 86        for stored_event in &events {
 87            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 88            if !edit_history.ends_with('\n') {
 89                edit_history.push('\n');
 90            }
 91        }
 92
 93        // Initialize an empty patch with context lines, to make it easy
 94        // to write the expected patch by hand.
 95        let mut expected_patches = Vec::new();
 96        let mut rejected_patch = None;
 97        if populate_expected_patch {
 98            let mut empty_patch = String::new();
 99            let start_row = cursor_excerpt_range.start.row + 1;
100            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
101            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
102            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
103            writeln!(
104                &mut empty_patch,
105                "@@ -{},{} +{},{} @@",
106                start_row, row_count, start_row, row_count,
107            )
108            .ok();
109            for line in cursor_excerpt.lines() {
110                writeln!(&mut empty_patch, " {}", line).ok();
111            }
112
113            expected_patches.push(empty_patch.clone());
114            rejected_patch = Some(empty_patch);
115        }
116
117        let prompt_input = cursor_file_content.map(|content| {
118            let captured_events: Vec<CapturedEvent> = events
119                .iter()
120                .map(|stored_event| {
121                    let zeta_prompt::Event::BufferChange {
122                        path,
123                        old_path,
124                        diff,
125                        predicted,
126                        in_open_source_repo,
127                    } = stored_event.event.as_ref();
128                    CapturedEvent {
129                        path: strip_root_name(path, &root_name).into(),
130                        old_path: strip_root_name(old_path, &root_name).into(),
131                        diff: diff.clone(),
132                        predicted: *predicted,
133                        in_open_source_repo: *in_open_source_repo,
134                    }
135                })
136                .collect();
137
138            let captured_related_files: Vec<CapturedRelatedFile> = related_files
139                .iter()
140                .map(|rf| CapturedRelatedFile {
141                    path: strip_root_name(&rf.path, &root_name).into(),
142                    max_row: rf.max_row,
143                    excerpts: rf
144                        .excerpts
145                        .iter()
146                        .map(|e| CapturedRelatedExcerpt {
147                            row_range: e.row_range.clone(),
148                            text: e.text.to_string(),
149                        })
150                        .collect(),
151                })
152                .collect();
153
154            CapturedPromptInput {
155                cursor_file_content: content,
156                cursor_offset: full_cursor_offset,
157                cursor_row: cursor_point.row,
158                cursor_column: cursor_point.column,
159                events: captured_events,
160                related_files: captured_related_files,
161            }
162        });
163
164        let mut spec = ExampleSpec {
165            name: generate_timestamp_name(),
166            repository_url,
167            revision,
168            tags: Vec::new(),
169            reasoning: None,
170            uncommitted_diff,
171            cursor_path,
172            cursor_position: String::new(),
173            edit_history,
174            expected_patches,
175            rejected_patch,
176            captured_prompt_input: prompt_input,
177            telemetry: None,
178        };
179        spec.set_cursor_excerpt(
180            &cursor_excerpt,
181            cursor_offset_in_excerpt,
182            &line_comment_prefix,
183        );
184        Ok(spec)
185    }))
186}
187
/// Returns `path` with the leading `root_name` prefix removed.
///
/// When `path` does not start with `root_name` (compared component-wise, as
/// `Path::strip_prefix` does), `path` is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative_path) => relative_path,
        Err(_) => path,
    }
}
191
192fn write_event_with_relative_paths(
193    output: &mut String,
194    event: &zeta_prompt::Event,
195    root_name: &str,
196) {
197    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
198        for component in strip_root_name(path, root_name).components() {
199            output.push('/');
200            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
201        }
202    }
203
204    let zeta_prompt::Event::BufferChange {
205        path,
206        old_path,
207        diff,
208        ..
209    } = event;
210
211    output.push_str("--- a");
212    write_relative_path(output, old_path.as_ref(), root_name);
213    output.push_str("\n+++ b");
214    write_relative_path(output, path.as_ref(), root_name);
215    output.push('\n');
216    output.push_str(diff);
217}
218
219fn compute_cursor_excerpt(
220    snapshot: &language::BufferSnapshot,
221    cursor_anchor: language::Anchor,
222) -> (String, usize, Range<Point>) {
223    use text::ToOffset as _;
224
225    let cursor_point = cursor_anchor.to_point(snapshot);
226    let (_editable_range, context_range) =
227        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
228    let context_start_offset = context_range.start.to_offset(snapshot);
229    let cursor_offset = cursor_anchor.to_offset(snapshot);
230    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
231    let excerpt = snapshot
232        .text_for_range(context_range.clone())
233        .collect::<String>();
234    (excerpt, cursor_offset_in_excerpt, context_range)
235}
236
237async fn collect_snapshots(
238    project: &Entity<Project>,
239    git_store: &Entity<project::git_store::GitStore>,
240    worktree_id: WorktreeId,
241    events: &[StoredEvent],
242    cx: &mut gpui::AsyncApp,
243) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
244    let mut snapshots_by_path = HashMap::default();
245    for stored_event in events {
246        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
247        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
248            let project_path = project
249                .find_project_path(path, cx)
250                .filter(|path| path.worktree_id == worktree_id)?;
251            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
252            Some((project_path, relative_path))
253        }) {
254            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
255                let buffer = project
256                    .update(cx, |project, cx| {
257                        project.open_buffer(project_path.clone(), cx)
258                    })
259                    .await?;
260                let diff = git_store
261                    .update(cx, |git_store, cx| {
262                        git_store.open_uncommitted_diff(buffer.clone(), cx)
263                    })
264                    .await?;
265                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
266                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
267            }
268        }
269    }
270    Ok(snapshots_by_path)
271}
272
273fn compute_uncommitted_diff(
274    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
275) -> String {
276    let mut uncommitted_diff = String::new();
277    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
278        if let Some(head_text) = &diff_snapshot.base_text_string() {
279            let file_diff = language::unified_diff(head_text, &before_text.text());
280            if !file_diff.is_empty() {
281                let path_str = relative_path.to_string_lossy();
282                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
283                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
284                uncommitted_diff.push_str(&file_diff);
285                if !uncommitted_diff.ends_with('\n') {
286                    uncommitted_diff.push('\n');
287                }
288            }
289        }
290    }
291    uncommitted_diff
292}
293
294fn generate_timestamp_name() -> String {
295    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
296    match format {
297        Ok(format) => {
298            let now = time::OffsetDateTime::now_local()
299                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
300            now.format(&format)
301                .unwrap_or_else(|_| "unknown-time".to_string())
302        }
303        Err(_) => "unknown-time".to_string(),
304    }
305}
306
307pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
308    let default_rate = if cx.is_staff() {
309        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
310    } else {
311        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
312    };
313    let capture_rate = language::language_settings::all_language_settings(None, cx)
314        .edit_predictions
315        .example_capture_rate
316        .unwrap_or(default_rate);
317    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
318        && rand::random::<u16>() % 10_000 < capture_rate
319}
320
321pub(crate) fn should_send_testing_zeta2_request() -> bool {
322    rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`: a file with committed, on-disk
    /// and in-buffer changes is captured, while an edit to a file outside the
    /// project worktree is recorded as an event but left out of the example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted comments.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Make two in-buffer edits on top of the on-disk contents.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    Vec::new(),
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is time-based; pin it so the comparison is stable.
        example.name = "test".to_string();

        // Expected: the uncommitted diff holds only the on-disk comments, the
        // edit history holds only the in-buffer edits to `src/main.rs`, and
        // the external file appears nowhere.
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: vec![
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ],
                rejected_patch: Some(
                    indoc! {"
                        --- a/src/main.rs
                        +++ b/src/main.rs
                        @@ -1,16 +1,16 @@
                         fn main() {
                             // comment 1
                             one();
                             two();
                             // comment 4
                             three();
                             four();
                             // comment 3
                             five();
                             six();
                             seven();
                             eight();
                             // comment 2
                             nine();
                         }
                    "}
                    .to_string()
                ),
                // Checked field by field below rather than as a whole.
                captured_prompt_input: example.captured_prompt_input.clone(),
                telemetry: None,
            }
        );

        let prompt_input = example
            .captured_prompt_input
            .expect("should have captured prompt input");
        assert!(
            prompt_input.cursor_file_content.contains("fn main()"),
            "cursor_file_content should contain file content"
        );
        assert_eq!(
            prompt_input.cursor_offset, 0,
            "cursor at Anchor::MIN should be offset 0"
        );
        assert_eq!(
            prompt_input.cursor_row, 0,
            "cursor at Anchor::MIN should be row 0"
        );
        assert_eq!(
            prompt_input.cursor_column, 0,
            "cursor at Anchor::MIN should be column 0"
        );
        assert!(
            !prompt_input.events.is_empty(),
            "should have captured events"
        );
        assert_eq!(
            prompt_input.related_files.len(),
            0,
            "should have no related files (none passed)"
        );
    }

    /// Installs the globals the test relies on: a test settings store, test
    /// logging, a client backed by an HTTP client that answers 404, the
    /// language-model subsystem, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}