capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
 13use text::{BufferSnapshot as TextBufferSnapshot, Point};
 14
/// Fallback capture rate, expressed as "captures per 10,000 predictions", used by
/// `should_sample_edit_prediction_example_capture` for non-staff users when the
/// `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
/// Fallback capture rate (per 10,000 predictions) used instead of the default one
/// when `cx.is_staff()` is true and the `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
/// Threshold (per 10,000 random draws) below which
/// `should_send_testing_zeta2_request` returns `true`.
pub(crate) const ZETA2_TESTING_RATE_PER_10K_PREDICTION: u16 = 500;
 18
 19pub fn capture_example(
 20    project: Entity<Project>,
 21    buffer: Entity<Buffer>,
 22    cursor_anchor: language::Anchor,
 23    mut events: Vec<StoredEvent>,
 24    populate_expected_patch: bool,
 25    cx: &mut App,
 26) -> Option<Task<Result<ExampleSpec>>> {
 27    let snapshot = buffer.read(cx).snapshot();
 28    let file = snapshot.file()?;
 29    let worktree_id = file.worktree_id(cx);
 30    let repository = project.read(cx).active_repository(cx)?;
 31    let repository_snapshot = repository.read(cx).snapshot();
 32    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 33    let root_name = worktree.read(cx).root_name_str().to_owned();
 34    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 35    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 36        return None;
 37    }
 38
 39    let repository_url = repository_snapshot
 40        .remote_origin_url
 41        .clone()
 42        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 43    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 44
 45    let git_store = project.read(cx).git_store().clone();
 46
 47    Some(cx.spawn(async move |mut cx| {
 48        let snapshots_by_path =
 49            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 50
 51        events.retain(|stored_event| {
 52            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 53            let relative_path = strip_root_name(path, &root_name);
 54            snapshots_by_path.contains_key(relative_path)
 55        });
 56
 57        let line_comment_prefix = snapshot
 58            .language()
 59            .and_then(|lang| lang.config().line_comments.first())
 60            .map(|s| s.to_string())
 61            .unwrap_or_default();
 62        let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
 63            .background_executor()
 64            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 65            .await;
 66        let uncommitted_diff = cx
 67            .background_executor()
 68            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 69            .await;
 70
 71        let mut edit_history = String::new();
 72        for stored_event in &events {
 73            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 74            if !edit_history.ends_with('\n') {
 75                edit_history.push('\n');
 76            }
 77        }
 78
 79        // Initialize an empty patch with context lines, to make it easy
 80        // to write the expected patch by hand.
 81        let mut expected_patches = Vec::new();
 82        let mut rejected_patch = None;
 83        if populate_expected_patch {
 84            let mut empty_patch = String::new();
 85            let start_row = cursor_excerpt_range.start.row + 1;
 86            let row_count = cursor_excerpt_range.end.row - cursor_excerpt_range.start.row + 1;
 87            writeln!(&mut empty_patch, "--- a/{}", cursor_path.display()).ok();
 88            writeln!(&mut empty_patch, "+++ b/{}", cursor_path.display()).ok();
 89            writeln!(
 90                &mut empty_patch,
 91                "@@ -{},{} +{},{} @@",
 92                start_row, row_count, start_row, row_count,
 93            )
 94            .ok();
 95            for line in cursor_excerpt.lines() {
 96                writeln!(&mut empty_patch, " {}", line).ok();
 97            }
 98
 99            expected_patches.push(empty_patch.clone());
100            rejected_patch = Some(empty_patch);
101        }
102
103        let mut spec = ExampleSpec {
104            name: generate_timestamp_name(),
105            repository_url,
106            revision,
107            tags: Vec::new(),
108            reasoning: None,
109            uncommitted_diff,
110            cursor_path,
111            cursor_position: String::new(),
112            edit_history,
113            expected_patches,
114            rejected_patch,
115        };
116        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
117        Ok(spec)
118    }))
119}
120
/// Returns `path` with a leading `root_name` component prefix removed.
///
/// If `path` does not start with `root_name` (compared component-wise), it is
/// returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
124
125fn write_event_with_relative_paths(
126    output: &mut String,
127    event: &zeta_prompt::Event,
128    root_name: &str,
129) {
130    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
131        for component in strip_root_name(path, root_name).components() {
132            output.push('/');
133            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
134        }
135    }
136
137    let zeta_prompt::Event::BufferChange {
138        path,
139        old_path,
140        diff,
141        ..
142    } = event;
143
144    output.push_str("--- a");
145    write_relative_path(output, old_path.as_ref(), root_name);
146    output.push_str("\n+++ b");
147    write_relative_path(output, path.as_ref(), root_name);
148    output.push('\n');
149    output.push_str(diff);
150}
151
152fn compute_cursor_excerpt(
153    snapshot: &language::BufferSnapshot,
154    cursor_anchor: language::Anchor,
155) -> (String, usize, Range<Point>) {
156    use text::ToOffset as _;
157
158    let cursor_point = cursor_anchor.to_point(snapshot);
159    let (_editable_range, context_range) =
160        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
161    let context_start_offset = context_range.start.to_offset(snapshot);
162    let cursor_offset = cursor_anchor.to_offset(snapshot);
163    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
164    let excerpt = snapshot
165        .text_for_range(context_range.clone())
166        .collect::<String>();
167    (excerpt, cursor_offset_in_excerpt, context_range)
168}
169
170async fn collect_snapshots(
171    project: &Entity<Project>,
172    git_store: &Entity<project::git_store::GitStore>,
173    worktree_id: WorktreeId,
174    events: &[StoredEvent],
175    cx: &mut gpui::AsyncApp,
176) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
177    let mut snapshots_by_path = HashMap::default();
178    for stored_event in events {
179        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
180        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
181            let project_path = project
182                .find_project_path(path, cx)
183                .filter(|path| path.worktree_id == worktree_id)?;
184            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
185            Some((project_path, relative_path))
186        }) {
187            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
188                let buffer = project
189                    .update(cx, |project, cx| {
190                        project.open_buffer(project_path.clone(), cx)
191                    })
192                    .await?;
193                let diff = git_store
194                    .update(cx, |git_store, cx| {
195                        git_store.open_uncommitted_diff(buffer.clone(), cx)
196                    })
197                    .await?;
198                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
199                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
200            }
201        }
202    }
203    Ok(snapshots_by_path)
204}
205
206fn compute_uncommitted_diff(
207    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
208) -> String {
209    let mut uncommitted_diff = String::new();
210    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
211        if let Some(head_text) = &diff_snapshot.base_text_string() {
212            let file_diff = language::unified_diff(head_text, &before_text.text());
213            if !file_diff.is_empty() {
214                let path_str = relative_path.to_string_lossy();
215                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
216                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
217                uncommitted_diff.push_str(&file_diff);
218                if !uncommitted_diff.ends_with('\n') {
219                    uncommitted_diff.push('\n');
220                }
221            }
222        }
223    }
224    uncommitted_diff
225}
226
227fn generate_timestamp_name() -> String {
228    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
229    match format {
230        Ok(format) => {
231            let now = time::OffsetDateTime::now_local()
232                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
233            now.format(&format)
234                .unwrap_or_else(|_| "unknown-time".to_string())
235        }
236        Err(_) => "unknown-time".to_string(),
237    }
238}
239
240pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
241    let default_rate = if cx.is_staff() {
242        DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
243    } else {
244        DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS
245    };
246    let capture_rate = language::language_settings::all_language_settings(None, cx)
247        .edit_predictions
248        .example_capture_rate
249        .unwrap_or(default_rate);
250    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
251        && rand::random::<u16>() % 10_000 < capture_rate
252}
253
254pub(crate) fn should_send_testing_zeta2_request() -> bool {
255    rand::random::<u16>() % 10_000 < ZETA2_TESTING_RATE_PER_10K_PREDICTION
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`: a file with committed, on-disk and
    /// in-buffer changes is captured, while an edit to a file outside the project
    /// worktree is recorded in the events but left out of the resulting spec.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents at HEAD.
        let head_text = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted comments.
        let working_copy_text = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": working_copy_text,
                }
            }),
        )
        .await;

        // A file that lives outside of the main project.
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        let git_dir = Path::new("/project/.git");
        fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_text.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Two in-buffer edits, applied bottom-up so the second row index is unaffected
        // by the first insertion.
        buffer.update(cx, |buffer, cx| {
            let lower_insertion = Point::new(6, 0);
            buffer.edit(
                [(lower_insertion..lower_insertion, "    // comment 3\n")],
                None,
                cx,
            );
            let upper_insertion = Point::new(4, 0);
            buffer.edit(
                [(upper_insertion..upper_insertion, "    // comment 4\n")],
                None,
                cx,
            );

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit the external file, which is not part of the main worktree.
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let file_start = Point::new(0, 0);
            buffer.edit([(file_start..file_start, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // The external edit must show up as the most recent recorded event.
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let capture_task = cx.update(|cx| {
            capture_example(
                project.clone(),
                buffer.clone(),
                Anchor::MIN,
                events,
                true,
                cx,
            )
            .unwrap()
        });
        let mut example = capture_task.await.unwrap();
        // The generated name is time-based, so pin it before comparing.
        example.name = "test".to_string();

        // The pre-populated expected patch and rejected patch share the same text:
        // a hunk made only of context lines covering the cursor excerpt.
        let no_op_patch = indoc! {"
            --- a/src/main.rs
            +++ b/src/main.rs
            @@ -1,16 +1,16 @@
             fn main() {
                 // comment 1
                 one();
                 two();
                 // comment 4
                 three();
                 four();
                 // comment 3
                 five();
                 six();
                 seven();
                 eight();
                 // comment 2
                 nine();
             }
        "}
        .to_string();

        let expected = ExampleSpec {
            name: "test".to_string(),
            repository_url: "https://github.com/test/repo.git".to_string(),
            revision: "abc123def456".to_string(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: indoc! {"
                --- a/src/main.rs
                +++ b/src/main.rs
                @@ -1,4 +1,5 @@
                 fn main() {
                +    // comment 1
                     one();
                     two();
                     three();
                @@ -7,5 +8,6 @@
                     six();
                     seven();
                     eight();
                +    // comment 2
                     nine();
                 }
            "}
            .to_string(),
            cursor_path: Path::new("src/main.rs").into(),
            cursor_position: indoc! {"
                fn main() {
                ^[CURSOR_POSITION]
                    // comment 1
                    one();
                    two();
                    // comment 4
                    three();
                    four();
                    // comment 3
                    five();
                    six();
                    seven();
                    eight();
                    // comment 2
                    nine();
                }
            "}
            .to_string(),
            edit_history: indoc! {"
                --- a/src/main.rs
                +++ b/src/main.rs
                @@ -2,8 +2,10 @@
                     // comment 1
                     one();
                     two();
                +    // comment 4
                     three();
                     four();
                +    // comment 3
                     five();
                     six();
                     seven();
            "}
            .to_string(),
            expected_patches: vec![no_op_patch.clone()],
            rejected_patch: Some(no_op_patch),
        };

        pretty_assertions::assert_eq!(example, expected);
    }

    /// Installs the globals the test relies on: test settings, logging, a client
    /// backed by an HTTP client that answers 404, the language-model registry, and
    /// the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            cx.set_global(SettingsStore::test(cx));
            zlog::init_test();

            let client = Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            );
            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}