capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 13use text::{BufferSnapshot as TextBufferSnapshot, Point};
 14
 15pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
 16pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
 17
 18pub fn capture_example(
 19    project: Entity<Project>,
 20    buffer: Entity<Buffer>,
 21    cursor_anchor: language::Anchor,
 22    mut events: Vec<StoredEvent>,
 23    populate_expected_patch: bool,
 24    cx: &mut App,
 25) -> Option<Task<Result<ExampleSpec>>> {
 26    let snapshot = buffer.read(cx).snapshot();
 27    let file = snapshot.file()?;
 28    let worktree_id = file.worktree_id(cx);
 29    let repository = project.read(cx).active_repository(cx)?;
 30    let repository_snapshot = repository.read(cx).snapshot();
 31    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 32    let root_name = worktree.read(cx).root_name_str().to_owned();
 33    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 34    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 35        return None;
 36    }
 37
 38    let repository_url = repository_snapshot
 39        .remote_origin_url
 40        .clone()
 41        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 42    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 43
 44    let git_store = project.read(cx).git_store().clone();
 45
 46    Some(cx.spawn(async move |mut cx| {
 47        let snapshots_by_path =
 48            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 49
 50        events.retain(|stored_event| {
 51            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 52            let relative_path = strip_root_name(path, &root_name);
 53            snapshots_by_path.contains_key(relative_path)
 54        });
 55
 56        let line_comment_prefix = snapshot
 57            .language()
 58            .and_then(|lang| lang.config().line_comments.first())
 59            .map(|s| s.to_string())
 60            .unwrap_or_default();
 61        let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
 62            .background_executor()
 63            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 64            .await;
 65        let uncommitted_diff = cx
 66            .background_executor()
 67            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 68            .await;
 69
 70        let mut edit_history = String::new();
 71        for stored_event in &events {
 72            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 73            if !edit_history.ends_with('\n') {
 74                edit_history.push('\n');
 75            }
 76        }
 77
 78        // Initialize an empty patch with context lines, to make it easy
 79        // to write the expected patch by hand.
 80        let mut expected_patches = Vec::new();
 81        if populate_expected_patch {
 82            let mut empty_patch = String::new();
 83            let start_row = cursor_excerpt_range.start.row + 1;
 84            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 85            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 86            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 87            writeln!(
 88                &mut empty_patch,
 89                "@@ -{},{} +{},{} @@",
 90                start_row, row_count, start_row, row_count,
 91            )
 92            .ok();
 93            for line in cursor_excerpt.lines() {
 94                writeln!(&mut empty_patch, " {}", line).ok();
 95            }
 96            expected_patches.push(empty_patch);
 97        }
 98
 99        let mut spec = ExampleSpec {
100            name: generate_timestamp_name(),
101            repository_url,
102            revision,
103            tags: Vec::new(),
104            reasoning: None,
105            uncommitted_diff,
106            cursor_path,
107            cursor_position: String::new(),
108            edit_history,
109            expected_patches,
110        };
111        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
112        Ok(spec)
113    }))
114}
115
/// Returns `path` with the leading worktree root component(s) `root_name` removed.
///
/// Matching is component-wise (as with [`Path::strip_prefix`]); if `path` does not start with
/// `root_name`, it is returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
119
120fn write_event_with_relative_paths(
121    output: &mut String,
122    event: &zeta_prompt::Event,
123    root_name: &str,
124) {
125    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
126        for component in strip_root_name(path, root_name).components() {
127            output.push('/');
128            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
129        }
130    }
131
132    let zeta_prompt::Event::BufferChange {
133        path,
134        old_path,
135        diff,
136        ..
137    } = event;
138
139    output.push_str("--- a");
140    write_relative_path(output, old_path.as_ref(), root_name);
141    output.push_str("\n+++ b");
142    write_relative_path(output, path.as_ref(), root_name);
143    output.push('\n');
144    output.push_str(diff);
145}
146
147fn compute_cursor_excerpt(
148    snapshot: &language::BufferSnapshot,
149    cursor_anchor: language::Anchor,
150) -> (String, usize, Range<Point>) {
151    use text::ToOffset as _;
152
153    let cursor_point = cursor_anchor.to_point(snapshot);
154    let (_editable_range, context_range) =
155        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
156    let context_start_offset = context_range.start.to_offset(snapshot);
157    let cursor_offset = cursor_anchor.to_offset(snapshot);
158    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
159    let excerpt = snapshot
160        .text_for_range(context_range.clone())
161        .collect::<String>();
162    (excerpt, cursor_offset_in_excerpt, context_range)
163}
164
165async fn collect_snapshots(
166    project: &Entity<Project>,
167    git_store: &Entity<project::git_store::GitStore>,
168    worktree_id: WorktreeId,
169    events: &[StoredEvent],
170    cx: &mut gpui::AsyncApp,
171) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
172    let mut snapshots_by_path = HashMap::default();
173    for stored_event in events {
174        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
175        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
176            let project_path = project
177                .find_project_path(path, cx)
178                .filter(|path| path.worktree_id == worktree_id)?;
179            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
180            Some((project_path, relative_path))
181        }) {
182            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
183                let buffer = project
184                    .update(cx, |project, cx| {
185                        project.open_buffer(project_path.clone(), cx)
186                    })
187                    .await?;
188                let diff = git_store
189                    .update(cx, |git_store, cx| {
190                        git_store.open_uncommitted_diff(buffer.clone(), cx)
191                    })
192                    .await?;
193                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
194                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
195            }
196        }
197    }
198    Ok(snapshots_by_path)
199}
200
201fn compute_uncommitted_diff(
202    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
203) -> String {
204    let mut uncommitted_diff = String::new();
205    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
206        if let Some(head_text) = &diff_snapshot.base_text_string() {
207            let file_diff = language::unified_diff(head_text, &before_text.text());
208            if !file_diff.is_empty() {
209                let path_str = relative_path.to_string_lossy();
210                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
211                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
212                uncommitted_diff.push_str(&file_diff);
213                if !uncommitted_diff.ends_with('\n') {
214                    uncommitted_diff.push('\n');
215                }
216            }
217        }
218    }
219    uncommitted_diff
220}
221
222fn generate_timestamp_name() -> String {
223    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
224    match format {
225        Ok(format) => {
226            let now = time::OffsetDateTime::now_local()
227                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
228            now.format(&format)
229                .unwrap_or_else(|_| "unknown-time".to_string())
230        }
231        Err(_) => "unknown-time".to_string(),
232    }
233}
234
235pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
236    let default_rate = if cx.is_staff() {
237        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
238    } else {
239        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
240    };
241    let capture_rate = language::language_settings::all_language_settings(None, cx)
242        .edit_predictions
243        .example_capture_rate
244        .unwrap_or(default_rate);
245    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
246        && rand::random::<u16>() % 10_000 < capture_rate
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// The repo's HEAD differs from disk by "comment 1" and "comment 2". The test then edits
    /// the buffer (adding "comment 3" and "comment 4") and also edits a file outside the
    /// worktree. It verifies that:
    /// - the uncommitted diff covers only HEAD vs. the pre-edit contents;
    /// - the edit history covers only the in-worktree edits;
    /// - the external edit is dropped.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents recorded as HEAD for src/main.rs.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // On-disk contents: HEAD plus "comment 1" and "comment 2" (uncommitted changes).
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // Give the repo a HEAD commit and an origin remote; `capture_example` returns `None`
        // without both.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer so its edits are recorded as events.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Insert "comment 3" before `five();` (row 6), then "comment 4" before `three();`
        // (row 4); the later-row edit goes first so the earlier row is unaffected.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer and an expected patch
        // populated.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is a timestamp; overwrite it so the comparison is stable.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. the contents before the recorded edits: only comments 1 and 2.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the in-worktree edits (comments 3 and 4); the external-file event is
                // filtered out.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // NOTE(review): the header says 16 lines but only 15 context lines follow. This
                // pins the current row-count computation, which presumably counts the empty row
                // after the trailing newline.
                expected_patches: vec![
                    indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,16 +1,16 @@
                     fn main() {
                         // comment 1
                         one();
                         two();
                         // comment 4
                         three();
                         four();
                         // comment 3
                         five();
                         six();
                         seven();
                         eight();
                         // comment 2
                         nine();
                     }
                "}
                    .to_string()
                ]
            }
        );
    }

    /// Installs test settings, logging, a client with a fake HTTP client, the language-model
    /// registry, and the global `EditPredictionStore` used by the test above.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}
522    }
523}