capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 13use text::{BufferSnapshot as TextBufferSnapshot, Point};
 14
/// Default number of predictions out of every 10,000 for which an example is captured.
/// Used by `should_sample_edit_prediction_example_capture` when the `example_capture_rate`
/// edit-prediction setting is unset.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Nominal number of calls out of every 10,000 for which
/// `should_send_testing_zeta2_request` returns `true`.
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
 17
 18pub fn capture_example(
 19    project: Entity<Project>,
 20    buffer: Entity<Buffer>,
 21    cursor_anchor: language::Anchor,
 22    mut events: Vec<StoredEvent>,
 23    populate_expected_patch: bool,
 24    cx: &mut App,
 25) -> Option<Task<Result<ExampleSpec>>> {
 26    let snapshot = buffer.read(cx).snapshot();
 27    let file = snapshot.file()?;
 28    let worktree_id = file.worktree_id(cx);
 29    let repository = project.read(cx).active_repository(cx)?;
 30    let repository_snapshot = repository.read(cx).snapshot();
 31    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 32    let root_name = worktree.read(cx).root_name_str().to_owned();
 33    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 34    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 35        return None;
 36    }
 37
 38    let repository_url = repository_snapshot
 39        .remote_origin_url
 40        .clone()
 41        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 42    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 43
 44    let git_store = project.read(cx).git_store().clone();
 45
 46    Some(cx.spawn(async move |mut cx| {
 47        let snapshots_by_path =
 48            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 49
 50        events.retain(|stored_event| {
 51            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 52            let relative_path = strip_root_name(path, &root_name);
 53            snapshots_by_path.contains_key(relative_path)
 54        });
 55
 56        let line_comment_prefix = snapshot
 57            .language()
 58            .and_then(|lang| lang.config().line_comments.first())
 59            .map(|s| s.to_string())
 60            .unwrap_or_default();
 61        let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
 62            .background_executor()
 63            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 64            .await;
 65        let uncommitted_diff = cx
 66            .background_executor()
 67            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 68            .await;
 69
 70        let mut edit_history = String::new();
 71        for stored_event in &events {
 72            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 73            if !edit_history.ends_with('\n') {
 74                edit_history.push('\n');
 75            }
 76        }
 77
 78        // Initialize an empty patch with context lines, to make it easy
 79        // to write the expected patch by hand.
 80        let mut expected_patches = Vec::new();
 81        if populate_expected_patch {
 82            let mut empty_patch = String::new();
 83            let start_row = cursor_excerpt_range.start.row + 1;
 84            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 85            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 86            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 87            writeln!(
 88                &mut empty_patch,
 89                "@@ -{},{} +{},{} @@",
 90                start_row, row_count, start_row, row_count,
 91            )
 92            .ok();
 93            for line in cursor_excerpt.lines() {
 94                writeln!(&mut empty_patch, " {}", line).ok();
 95            }
 96            expected_patches.push(empty_patch);
 97        }
 98
 99        let mut spec = ExampleSpec {
100            name: generate_timestamp_name(),
101            repository_url,
102            revision,
103            tags: Vec::new(),
104            reasoning: None,
105            uncommitted_diff,
106            cursor_path,
107            cursor_position: String::new(),
108            edit_history,
109            expected_patches,
110        };
111        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
112        Ok(spec)
113    }))
114}
115
/// Returns `path` with a leading `root_name` component prefix removed.
///
/// When `path` does not start with `root_name` (compared component-wise), `path` is returned
/// unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
119
120fn write_event_with_relative_paths(
121    output: &mut String,
122    event: &zeta_prompt::Event,
123    root_name: &str,
124) {
125    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
126        for component in strip_root_name(path, root_name).components() {
127            output.push('/');
128            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
129        }
130    }
131
132    let zeta_prompt::Event::BufferChange {
133        path,
134        old_path,
135        diff,
136        ..
137    } = event;
138
139    output.push_str("--- a");
140    write_relative_path(output, old_path.as_ref(), root_name);
141    output.push_str("\n+++ b");
142    write_relative_path(output, path.as_ref(), root_name);
143    output.push('\n');
144    output.push_str(diff);
145}
146
147fn compute_cursor_excerpt(
148    snapshot: &language::BufferSnapshot,
149    cursor_anchor: language::Anchor,
150) -> (String, usize, Range<Point>) {
151    use text::ToOffset as _;
152
153    let cursor_point = cursor_anchor.to_point(snapshot);
154    let (_editable_range, context_range) =
155        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
156    let context_start_offset = context_range.start.to_offset(snapshot);
157    let cursor_offset = cursor_anchor.to_offset(snapshot);
158    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
159    let excerpt = snapshot
160        .text_for_range(context_range.clone())
161        .collect::<String>();
162    (excerpt, cursor_offset_in_excerpt, context_range)
163}
164
165async fn collect_snapshots(
166    project: &Entity<Project>,
167    git_store: &Entity<project::git_store::GitStore>,
168    worktree_id: WorktreeId,
169    events: &[StoredEvent],
170    cx: &mut gpui::AsyncApp,
171) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
172    let mut snapshots_by_path = HashMap::default();
173    for stored_event in events {
174        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
175        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
176            let project_path = project
177                .find_project_path(path, cx)
178                .filter(|path| path.worktree_id == worktree_id)?;
179            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
180            Some((project_path, relative_path))
181        }) {
182            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
183                let buffer = project
184                    .update(cx, |project, cx| {
185                        project.open_buffer(project_path.clone(), cx)
186                    })
187                    .await?;
188                let diff = git_store
189                    .update(cx, |git_store, cx| {
190                        git_store.open_uncommitted_diff(buffer.clone(), cx)
191                    })
192                    .await?;
193                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
194                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
195            }
196        }
197    }
198    Ok(snapshots_by_path)
199}
200
201fn compute_uncommitted_diff(
202    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
203) -> String {
204    let mut uncommitted_diff = String::new();
205    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
206        if let Some(head_text) = &diff_snapshot.base_text_string() {
207            let file_diff = language::unified_diff(head_text, &before_text.text());
208            if !file_diff.is_empty() {
209                let path_str = relative_path.to_string_lossy();
210                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
211                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
212                uncommitted_diff.push_str(&file_diff);
213                if !uncommitted_diff.ends_with('\n') {
214                    uncommitted_diff.push('\n');
215                }
216            }
217        }
218    }
219    uncommitted_diff
220}
221
222fn generate_timestamp_name() -> String {
223    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
224    match format {
225        Ok(format) => {
226            let now = time::OffsetDateTime::now_local()
227                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
228            now.format(&format)
229                .unwrap_or_else(|_| "unknown-time".to_string())
230        }
231        Err(_) => "unknown-time".to_string(),
232    }
233}
234
235pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
236    let capture_rate = language::language_settings::all_language_settings(None, cx)
237        .edit_predictions
238        .example_capture_rate
239        .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
240    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
241        && rand::random::<u16>() % 10_000 < capture_rate
242}
243
244pub(crate) fn should_send_testing_zeta2_request() -> bool {
245    rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` against a fake file system and repository.
    ///
    /// The fixture has three layers of content for `src/main.rs`:
    /// - the committed (HEAD) contents,
    /// - the on-disk contents, which add `// comment 1` and `// comment 2`,
    /// - in-memory buffer edits, which add `// comment 3` and `// comment 4`.
    ///
    /// It also edits a file outside the project worktree and asserts that this edit is
    /// recorded in the events but is absent from the captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // HEAD holds `committed_contents`, so the two comment lines already on disk count as
        // uncommitted changes.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the store before editing it; the events are read back
        // from the store further down.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Two in-memory edits on top of the on-disk contents. The later row is edited first,
        // so inserting at row 4 afterwards does not shift the first insertion point.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at `Anchor::MIN`; `true` requests the pre-populated
        // expected patch.
        let mut example = cx
            .update(|cx| {
                capture_example(
                    project.clone(),
                    buffer.clone(),
                    Anchor::MIN,
                    events,
                    true,
                    cx,
                )
                .unwrap()
            })
            .await
            .unwrap();
        // The generated name is timestamp-based, so pin it before comparing.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD versus the on-disk contents: only comments 1 and 2.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the in-project buffer edits (comments 3 and 4); the
                // `/external/external.rs` event is absent.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // Context-only patch over the whole cursor excerpt.
                expected_patches: vec![
                    indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,16 +1,16 @@
                     fn main() {
                         // comment 1
                         one();
                         two();
                         // comment 4
                         three();
                         four();
                         // comment 3
                         five();
                         six();
                         seven();
                         eight();
                         // comment 2
                         nine();
                     }
                "}
                    .to_string()
                ]
            }
        );
    }

    /// Installs the globals the test relies on: a test `SettingsStore`, test logging, a
    /// `Client` backed by a fake HTTP client, the language-model subsystem, a `UserStore`,
    /// and (presumably, via `EditPredictionStore::global`) the global `EditPredictionStore`
    /// that the test later retrieves with `try_global`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}