capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 13use text::{BufferSnapshot as TextBufferSnapshot, Point};
 14
/// Fallback number of edit predictions, out of every 10,000, that are sampled for example
/// capture when the settings specify no capture rate and the user is not staff.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Fallback number of edit predictions, out of every 10,000, that are sampled for example
/// capture when the settings specify no capture rate and the user is staff.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
 17
 18pub fn capture_example(
 19    project: Entity<Project>,
 20    buffer: Entity<Buffer>,
 21    cursor_anchor: language::Anchor,
 22    mut events: Vec<StoredEvent>,
 23    populate_expected_patch: bool,
 24    cx: &mut App,
 25) -> Option<Task<Result<ExampleSpec>>> {
 26    let snapshot = buffer.read(cx).snapshot();
 27    let file = snapshot.file()?;
 28    let worktree_id = file.worktree_id(cx);
 29    let repository = project.read(cx).active_repository(cx)?;
 30    let repository_snapshot = repository.read(cx).snapshot();
 31    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 32    let root_name = worktree.read(cx).root_name_str().to_owned();
 33    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 34    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 35        return None;
 36    }
 37
 38    let repository_url = repository_snapshot
 39        .remote_origin_url
 40        .clone()
 41        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 42    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 43
 44    let git_store = project.read(cx).git_store().clone();
 45
 46    Some(cx.spawn(async move |mut cx| {
 47        let snapshots_by_path =
 48            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 49
 50        events.retain(|stored_event| {
 51            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 52            let relative_path = strip_root_name(path, &root_name);
 53            snapshots_by_path.contains_key(relative_path)
 54        });
 55
 56        let line_comment_prefix = snapshot
 57            .language()
 58            .and_then(|lang| lang.config().line_comments.first())
 59            .map(|s| s.to_string())
 60            .unwrap_or_default();
 61        let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
 62            .background_executor()
 63            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 64            .await;
 65        let uncommitted_diff = cx
 66            .background_executor()
 67            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 68            .await;
 69
 70        let mut edit_history = String::new();
 71        for stored_event in &events {
 72            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 73            if !edit_history.ends_with('\n') {
 74                edit_history.push('\n');
 75            }
 76        }
 77
 78        // Initialize an empty patch with context lines, to make it easy
 79        // to write the expected patch by hand.
 80        let mut expected_patches = Vec::new();
 81        let mut rejected_patch = None;
 82        if populate_expected_patch {
 83            let mut empty_patch = String::new();
 84            let start_row = cursor_excerpt_range.start.row + 1;
 85            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 86            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 87            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 88            writeln!(
 89                &mut empty_patch,
 90                "@@ -{},{} +{},{} @@",
 91                start_row, row_count, start_row, row_count,
 92            )
 93            .ok();
 94            for line in cursor_excerpt.lines() {
 95                writeln!(&mut empty_patch, " {}", line).ok();
 96            }
 97
 98            expected_patches.push(empty_patch.clone());
 99            rejected_patch = Some(empty_patch);
100        }
101
102        let mut spec = ExampleSpec {
103            name: generate_timestamp_name(),
104            repository_url,
105            revision,
106            tags: Vec::new(),
107            reasoning: None,
108            uncommitted_diff,
109            cursor_path,
110            cursor_position: String::new(),
111            edit_history,
112            expected_patches,
113            rejected_patch,
114        };
115        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
116        Ok(spec)
117    }))
118}
119
/// Returns `path` with the leading `root_name` component(s) removed, or `path` unchanged
/// when it does not start with `root_name`.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
123
124fn write_event_with_relative_paths(
125    output: &mut String,
126    event: &zeta_prompt::Event,
127    root_name: &str,
128) {
129    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
130        for component in strip_root_name(path, root_name).components() {
131            output.push('/');
132            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
133        }
134    }
135
136    let zeta_prompt::Event::BufferChange {
137        path,
138        old_path,
139        diff,
140        ..
141    } = event;
142
143    output.push_str("--- a");
144    write_relative_path(output, old_path.as_ref(), root_name);
145    output.push_str("\n+++ b");
146    write_relative_path(output, path.as_ref(), root_name);
147    output.push('\n');
148    output.push_str(diff);
149}
150
151fn compute_cursor_excerpt(
152    snapshot: &language::BufferSnapshot,
153    cursor_anchor: language::Anchor,
154) -> (String, usize, Range<Point>) {
155    use text::ToOffset as _;
156
157    let cursor_point = cursor_anchor.to_point(snapshot);
158    let (_editable_range, context_range) =
159        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
160    let context_start_offset = context_range.start.to_offset(snapshot);
161    let cursor_offset = cursor_anchor.to_offset(snapshot);
162    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
163    let excerpt = snapshot
164        .text_for_range(context_range.clone())
165        .collect::<String>();
166    (excerpt, cursor_offset_in_excerpt, context_range)
167}
168
169async fn collect_snapshots(
170    project: &Entity<Project>,
171    git_store: &Entity<project::git_store::GitStore>,
172    worktree_id: WorktreeId,
173    events: &[StoredEvent],
174    cx: &mut gpui::AsyncApp,
175) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
176    let mut snapshots_by_path = HashMap::default();
177    for stored_event in events {
178        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
179        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
180            let project_path = project
181                .find_project_path(path, cx)
182                .filter(|path| path.worktree_id == worktree_id)?;
183            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
184            Some((project_path, relative_path))
185        }) {
186            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
187                let buffer = project
188                    .update(cx, |project, cx| {
189                        project.open_buffer(project_path.clone(), cx)
190                    })
191                    .await?;
192                let diff = git_store
193                    .update(cx, |git_store, cx| {
194                        git_store.open_uncommitted_diff(buffer.clone(), cx)
195                    })
196                    .await?;
197                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
198                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
199            }
200        }
201    }
202    Ok(snapshots_by_path)
203}
204
205fn compute_uncommitted_diff(
206    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
207) -> String {
208    let mut uncommitted_diff = String::new();
209    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
210        if let Some(head_text) = &diff_snapshot.base_text_string() {
211            let file_diff = language::unified_diff(head_text, &before_text.text());
212            if !file_diff.is_empty() {
213                let path_str = relative_path.to_string_lossy();
214                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
215                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
216                uncommitted_diff.push_str(&file_diff);
217                if !uncommitted_diff.ends_with('\n') {
218                    uncommitted_diff.push('\n');
219                }
220            }
221        }
222    }
223    uncommitted_diff
224}
225
226fn generate_timestamp_name() -> String {
227    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
228    match format {
229        Ok(format) => {
230            let now = time::OffsetDateTime::now_local()
231                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
232            now.format(&format)
233                .unwrap_or_else(|_| "unknown-time".to_string())
234        }
235        Err(_) => "unknown-time".to_string(),
236    }
237}
238
239pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
240    let default_rate = if cx.is_staff() {
241        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
242    } else {
243        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
244    };
245    let capture_rate = language::language_settings::all_language_settings(None, cx)
246        .edit_predictions
247        .example_capture_rate
248        .unwrap_or(default_rate);
249    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
250        && rand::random::<u16>() % 10_000 < capture_rate
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// Sets up a repository whose on-disk file already differs from the committed version,
    /// makes two further buffer edits, and also edits a file outside the project. The
    /// captured spec must contain the disk-vs-HEAD changes as the uncommitted diff, the
    /// buffer edits as the edit history, and nothing about the external file.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fake_fs = FakeFs::new(cx.executor());

        let head_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        let worktree_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fake_fs
            .insert_tree(
                "/project",
                json!({
                    ".git": {},
                    "src": {
                        "main.rs": worktree_contents,
                    }
                }),
            )
            .await;

        // A file that lives outside the main project.
        fake_fs
            .insert_tree(
                "/external",
                json!({
                    "external.rs": "fn external() {}\n",
                }),
            )
            .await;

        let git_dir = Path::new("/project/.git");
        fake_fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_contents.to_string())],
            "abc123def456",
        );
        fake_fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fake_fs.clone(), ["/project".as_ref()], cx).await;

        let main_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let prediction_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        prediction_store.update(cx, |store, cx| {
            store.register_buffer(&main_buffer, &project, cx)
        });
        cx.run_until_parked();

        main_buffer.update(cx, |buffer, cx| {
            // The later row is edited first; both rows refer to the buffer as it is at the
            // time of each edit.
            for (row, inserted_line) in [(6, "    // comment 3\n"), (4, "    // comment 4\n")] {
                let insertion_point = Point::new(row, 0);
                buffer.edit([(insertion_point..insertion_point, inserted_line)], None, cx);
            }

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit a file outside the main project's worktree.
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        prediction_store.update(cx, |store, cx| {
            store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let start = Point::new(0, 0);
            buffer.edit([(start..start, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // The external edit must have been recorded as the most recent event.
        let recorded_events = prediction_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                recorded_events.last().unwrap().event.as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let capture_task = cx.update(|cx| {
            capture_example(
                project.clone(),
                main_buffer.clone(),
                Anchor::MIN,
                recorded_events,
                true,
                cx,
            )
            .unwrap()
        });
        let mut captured = capture_task.await.unwrap();
        // The generated name is time-dependent; pin it so the whole spec can be compared.
        captured.name = "test".to_string();

        // With `populate_expected_patch` set, the same context-only patch is used for both
        // the expected patches and the rejected patch.
        let context_only_patch = indoc! {"
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,16 +1,16 @@
             fn main() {
                 // comment 1
                 one();
                 two();
                 // comment 4
                 three();
                 four();
                 // comment 3
                 five();
                 six();
                 seven();
                 eight();
                 // comment 2
                 nine();
             }
        "}
        .to_string();

        let expected = ExampleSpec {
            name: "test".to_string(),
            repository_url: "https://github.com/test/repo.git".to_string(),
            revision: "abc123def456".to_string(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: indoc! {"
                --- a/src/main.rs
                +++ b/src/main.rs
                @@ -1,4 +1,5 @@
                 fn main() {
                +    // comment 1
                     one();
                     two();
                     three();
                @@ -7,5 +8,6 @@
                     six();
                     seven();
                     eight();
                +    // comment 2
                     nine();
                 }
            "}
            .to_string(),
            cursor_path: Path::new("src/main.rs").into(),
            cursor_position: indoc! {"
                fn main() {
                ^[CURSOR_POSITION]
                    // comment 1
                    one();
                    two();
                    // comment 4
                    three();
                    four();
                    // comment 3
                    five();
                    six();
                    seven();
                    eight();
                    // comment 2
                    nine();
                }
            "}
            .to_string(),
            edit_history: indoc! {"
                --- a/src/main.rs
                +++ b/src/main.rs
                @@ -2,8 +2,10 @@
                     // comment 1
                     one();
                     two();
                +    // comment 4
                     three();
                     four();
                +    // comment 3
                     five();
                     six();
                     seven();
            "}
            .to_string(),
            expected_patches: vec![context_only_patch.clone()],
            rejected_patch: Some(context_only_patch),
        };

        pretty_assertions::assert_eq!(captured, expected);
    }

    /// Installs the globals the test relies on: test settings, logging, a client backed by
    /// an HTTP stub that answers 404, the language model registry, and the global
    /// `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let test_settings = SettingsStore::test(cx);
            cx.set_global(test_settings);
            zlog::init_test();
            let client = Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            );
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}