capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 13use text::{BufferSnapshot as TextBufferSnapshot, Point};
 14
 /// Fallback capture rate, expressed as a count out of every 10,000 predictions.
 /// `should_sample_edit_prediction_example_capture` uses it when the
 /// `example_capture_rate` edit-prediction setting is unset.
 15pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
 16
 17pub fn capture_example(
 18    project: Entity<Project>,
 19    buffer: Entity<Buffer>,
 20    cursor_anchor: language::Anchor,
 21    mut events: Vec<StoredEvent>,
 22    populate_expected_patch: bool,
 23    cx: &mut App,
 24) -> Option<Task<Result<ExampleSpec>>> {
 25    let snapshot = buffer.read(cx).snapshot();
 26    let file = snapshot.file()?;
 27    let worktree_id = file.worktree_id(cx);
 28    let repository = project.read(cx).active_repository(cx)?;
 29    let repository_snapshot = repository.read(cx).snapshot();
 30    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 31    let root_name = worktree.read(cx).root_name_str().to_owned();
 32    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 33    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 34        return None;
 35    }
 36
 37    let repository_url = repository_snapshot
 38        .remote_origin_url
 39        .clone()
 40        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 41    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 42
 43    let git_store = project.read(cx).git_store().clone();
 44
 45    Some(cx.spawn(async move |mut cx| {
 46        let snapshots_by_path =
 47            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 48
 49        events.retain(|stored_event| {
 50            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 51            let relative_path = strip_root_name(path, &root_name);
 52            snapshots_by_path.contains_key(relative_path)
 53        });
 54
 55        let line_comment_prefix = snapshot
 56            .language()
 57            .and_then(|lang| lang.config().line_comments.first())
 58            .map(|s| s.to_string())
 59            .unwrap_or_default();
 60        let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
 61            .background_executor()
 62            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 63            .await;
 64        let uncommitted_diff = cx
 65            .background_executor()
 66            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 67            .await;
 68
 69        let mut edit_history = String::new();
 70        for stored_event in &events {
 71            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 72            if !edit_history.ends_with('\n') {
 73                edit_history.push('\n');
 74            }
 75        }
 76
 77        // Initialize an empty patch with context lines, to make it easy
 78        // to write the expected patch by hand.
 79        let mut expected_patches = Vec::new();
 80        if populate_expected_patch {
 81            let mut empty_patch = String::new();
 82            let start_row = cursor_excerpt_range.start.row + 1;
 83            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 84            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 85            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 86            writeln!(
 87                &mut empty_patch,
 88                "@@ -{},{} +{},{} @@",
 89                start_row, row_count, start_row, row_count,
 90            )
 91            .ok();
 92            for line in cursor_excerpt.lines() {
 93                writeln!(&mut empty_patch, " {}", line).ok();
 94            }
 95            expected_patches.push(empty_patch);
 96        }
 97
 98        let mut spec = ExampleSpec {
 99            name: generate_timestamp_name(),
100            repository_url,
101            revision,
102            tags: Vec::new(),
103            reasoning: None,
104            uncommitted_diff,
105            cursor_path,
106            cursor_position: String::new(),
107            edit_history,
108            expected_patches,
109        };
110        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
111        Ok(spec)
112    }))
113}
114
/// Returns `path` without its leading `root_name` component(s), or `path`
/// unchanged when it does not start with `root_name`.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
118
119fn write_event_with_relative_paths(
120    output: &mut String,
121    event: &zeta_prompt::Event,
122    root_name: &str,
123) {
124    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
125        for component in strip_root_name(path, root_name).components() {
126            output.push('/');
127            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
128        }
129    }
130
131    let zeta_prompt::Event::BufferChange {
132        path,
133        old_path,
134        diff,
135        ..
136    } = event;
137
138    output.push_str("--- a");
139    write_relative_path(output, old_path.as_ref(), root_name);
140    output.push_str("\n+++ b");
141    write_relative_path(output, path.as_ref(), root_name);
142    output.push('\n');
143    output.push_str(diff);
144}
145
146fn compute_cursor_excerpt(
147    snapshot: &language::BufferSnapshot,
148    cursor_anchor: language::Anchor,
149) -> (String, usize, Range<Point>) {
150    use text::ToOffset as _;
151
152    let cursor_point = cursor_anchor.to_point(snapshot);
153    let (_editable_range, context_range) =
154        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
155    let context_start_offset = context_range.start.to_offset(snapshot);
156    let cursor_offset = cursor_anchor.to_offset(snapshot);
157    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
158    let excerpt = snapshot
159        .text_for_range(context_range.clone())
160        .collect::<String>();
161    (excerpt, cursor_offset_in_excerpt, context_range)
162}
163
164async fn collect_snapshots(
165    project: &Entity<Project>,
166    git_store: &Entity<project::git_store::GitStore>,
167    worktree_id: WorktreeId,
168    events: &[StoredEvent],
169    cx: &mut gpui::AsyncApp,
170) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
171    let mut snapshots_by_path = HashMap::default();
172    for stored_event in events {
173        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
174        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
175            let project_path = project
176                .find_project_path(path, cx)
177                .filter(|path| path.worktree_id == worktree_id)?;
178            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
179            Some((project_path, relative_path))
180        }) {
181            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
182                let buffer = project
183                    .update(cx, |project, cx| {
184                        project.open_buffer(project_path.clone(), cx)
185                    })
186                    .await?;
187                let diff = git_store
188                    .update(cx, |git_store, cx| {
189                        git_store.open_uncommitted_diff(buffer.clone(), cx)
190                    })
191                    .await?;
192                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
193                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
194            }
195        }
196    }
197    Ok(snapshots_by_path)
198}
199
200fn compute_uncommitted_diff(
201    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
202) -> String {
203    let mut uncommitted_diff = String::new();
204    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
205        if let Some(head_text) = &diff_snapshot.base_text_string() {
206            let file_diff = language::unified_diff(head_text, &before_text.text());
207            if !file_diff.is_empty() {
208                let path_str = relative_path.to_string_lossy();
209                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
210                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
211                uncommitted_diff.push_str(&file_diff);
212                if !uncommitted_diff.ends_with('\n') {
213                    uncommitted_diff.push('\n');
214                }
215            }
216        }
217    }
218    uncommitted_diff
219}
220
221fn generate_timestamp_name() -> String {
222    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
223    match format {
224        Ok(format) => {
225            let now = time::OffsetDateTime::now_local()
226                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
227            now.format(&format)
228                .unwrap_or_else(|_| "unknown-time".to_string())
229        }
230        Err(_) => "unknown-time".to_string(),
231    }
232}
233
234pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
235    let capture_rate = language::language_settings::all_language_settings(None, cx)
236        .edit_predictions
237        .example_capture_rate
238        .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
239    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
240        && rand::random::<u16>() % 10_000 < capture_rate
241}
242
// Tests for `capture_example`, run against a fake filesystem and fake git state.
243#[cfg(test)]
244mod tests {
245    use super::*;
246    use crate::EditPredictionStore;
247    use client::{Client, UserStore};
248    use clock::FakeSystemClock;
249    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
250    use indoc::indoc;
251    use language::{Anchor, Point};
252    use project::{FakeFs, Project};
253    use serde_json::json;
254    use settings::SettingsStore;
255    use std::path::Path;
256
    // End-to-end check of `capture_example`.
    //
    // Sets up a repository whose HEAD contents differ from the on-disk file,
    // edits the open buffer, additionally edits a file outside the project
    // directory, and then compares the captured `ExampleSpec` field by field.
    // The expected `edit_history` and `uncommitted_diff` mention only
    // `src/main.rs`, so the edit to the external file must not show up in them.
257    #[gpui::test]
258    async fn test_capture_example(cx: &mut TestAppContext) {
259        init_test(cx);
260        let fs = FakeFs::new(cx.executor());
261
        // Contents recorded as the HEAD version of `src/main.rs`.
262        let committed_contents = indoc! {"
263            fn main() {
264                one();
265                two();
266                three();
267                four();
268                five();
269                six();
270                seven();
271                eight();
272                nine();
273            }
274        "};
275
        // Contents on disk: HEAD plus two comment lines ("comment 1" and
        // "comment 2"), which is what the expected `uncommitted_diff` shows.
276        let disk_contents = indoc! {"
277            fn main() {
278                // comment 1
279                one();
280                two();
281                three();
282                four();
283                five();
284                six();
285                seven();
286                eight();
287                // comment 2
288                nine();
289            }
290        "};
291
292        fs.insert_tree(
293            "/project",
294            json!({
295                ".git": {},
296                "src": {
297                    "main.rs": disk_contents,
298                }
299            }),
300        )
301        .await;
302
303        // Create an external file outside the main project
304        fs.insert_tree(
305            "/external",
306            json!({
307                "external.rs": "fn external() {}\n",
308            }),
309        )
310        .await;
311
        // Give the fake repository a HEAD commit and an `origin` remote;
        // `capture_example` returns `None` when either is missing.
312        fs.set_head_for_repo(
313            Path::new("/project/.git"),
314            &[("src/main.rs", committed_contents.to_string())],
315            "abc123def456",
316        );
317        fs.set_remote_for_repo(
318            Path::new("/project/.git"),
319            "origin",
320            "https://github.com/test/repo.git",
321        );
322
323        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
324
325        let buffer = project
326            .update(cx, |project, cx| {
327                project.open_local_buffer("/project/src/main.rs", cx)
328            })
329            .await
330            .unwrap();
331
        // Register the buffer with the store before editing it; the edits made
        // below are later read back from the store as events.
332        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
333        ep_store.update(cx, |ep_store, cx| {
334            ep_store.register_buffer(&buffer, &project, cx)
335        });
336        cx.run_until_parked();
337
        // Insert "comment 3" and "comment 4"; these two edits make up the
        // expected `edit_history`.
338        buffer.update(cx, |buffer, cx| {
339            let point = Point::new(6, 0);
340            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
341            let point = Point::new(4, 0);
342            buffer.edit([(point..point, "    // comment 4\n")], None, cx);
343
344            pretty_assertions::assert_eq!(
345                buffer.text(),
346                indoc! {"
347                    fn main() {
348                        // comment 1
349                        one();
350                        two();
351                        // comment 4
352                        three();
353                        four();
354                        // comment 3
355                        five();
356                        six();
357                        seven();
358                        eight();
359                        // comment 2
360                        nine();
361                    }
362                "}
363            );
364        });
365        cx.run_until_parked();
366
367        // Open and edit an external file (outside the main project's worktree)
368        let external_buffer = project
369            .update(cx, |project, cx| {
370                project.open_local_buffer("/external/external.rs", cx)
371            })
372            .await
373            .unwrap();
374        ep_store.update(cx, |ep_store, cx| {
375            ep_store.register_buffer(&external_buffer, &project, cx)
376        });
377        cx.run_until_parked();
378        external_buffer.update(cx, |buffer, cx| {
379            let point = Point::new(0, 0);
380            buffer.edit([(point..point, "// external edit\n")], None, cx);
381        });
382        cx.run_until_parked();
383
384        // Verify the external edit was recorded in events
385        let events = ep_store.update(cx, |store, cx| {
386            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
387        });
388        assert!(
389            matches!(
390                events
391                    .last()
392                    .unwrap()
393                    .event
394                    .as_ref(),
395                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
396            ),
397            "external file edit should be in events"
398        );
399
        // Capture with the cursor at `Anchor::MIN` and with
        // `populate_expected_patch` enabled.
400        let mut example = cx
401            .update(|cx| {
402                capture_example(
403                    project.clone(),
404                    buffer.clone(),
405                    Anchor::MIN,
406                    events,
407                    true,
408                    cx,
409                )
410                .unwrap()
411            })
412            .await
413            .unwrap();
        // The captured name is derived from the current time, so pin it to a
        // fixed value before comparing whole structs.
414        example.name = "test".to_string();
415
416        pretty_assertions::assert_eq!(
417            example,
418            ExampleSpec {
419                name: "test".to_string(),
420                repository_url: "https://github.com/test/repo.git".to_string(),
421                revision: "abc123def456".to_string(),
422                tags: Vec::new(),
423                reasoning: None,
424                uncommitted_diff: indoc! {"
425                    --- a/src/main.rs
426                    +++ b/src/main.rs
427                    @@ -1,4 +1,5 @@
428                     fn main() {
429                    +    // comment 1
430                         one();
431                         two();
432                         three();
433                    @@ -7,5 +8,6 @@
434                         six();
435                         seven();
436                         eight();
437                    +    // comment 2
438                         nine();
439                     }
440                "}
441                .to_string(),
442                cursor_path: Path::new("src/main.rs").into(),
443                cursor_position: indoc! {"
444                    fn main() {
445                    ^[CURSOR_POSITION]
446                        // comment 1
447                        one();
448                        two();
449                        // comment 4
450                        three();
451                        four();
452                        // comment 3
453                        five();
454                        six();
455                        seven();
456                        eight();
457                        // comment 2
458                        nine();
459                    }
460                "}
461                .to_string(),
462                edit_history: indoc! {"
463                    --- a/src/main.rs
464                    +++ b/src/main.rs
465                    @@ -2,8 +2,10 @@
466                         // comment 1
467                         one();
468                         two();
469                    +    // comment 4
470                         three();
471                         four();
472                    +    // comment 3
473                         five();
474                         six();
475                         seven();
476                "}
477                .to_string(),
478                expected_patches: vec![
479                    indoc! {"
480                    --- a/src/main.rs
481                    +++ b/src/main.rs
482                    @@ -1,16 +1,16 @@
483                     fn main() {
484                         // comment 1
485                         one();
486                         two();
487                         // comment 4
488                         three();
489                         four();
490                         // comment 3
491                         five();
492                         six();
493                         seven();
494                         eight();
495                         // comment 2
496                         nine();
497                     }
498                "}
499                    .to_string()
500                ]
501            }
502        );
503    }
504
    // Installs the globals the test depends on: a test `SettingsStore`, test
    // logging, a `Client` built on a fake clock and the fake HTTP client from
    // `FakeHttpClient::with_404_response`, the language-model initialization,
    // a `UserStore`, and the global `EditPredictionStore`.
505    fn init_test(cx: &mut TestAppContext) {
506        cx.update(|cx| {
507            let settings_store = SettingsStore::test(cx);
508            cx.set_global(settings_store);
509            zlog::init_test();
510            let http_client = FakeHttpClient::with_404_response();
511            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
512            language_model::init(client.clone(), cx);
513            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
514            EditPredictionStore::global(&client, &user_store, cx);
515        })
516    }
517}