capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 13use text::BufferSnapshot as TextBufferSnapshot;
 14
 15pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
 16
 17pub fn capture_example(
 18    project: Entity<Project>,
 19    buffer: Entity<Buffer>,
 20    cursor_anchor: language::Anchor,
 21    mut events: Vec<StoredEvent>,
 22    cx: &mut App,
 23) -> Option<Task<Result<ExampleSpec>>> {
 24    let snapshot = buffer.read(cx).snapshot();
 25    let file = snapshot.file()?;
 26    let worktree_id = file.worktree_id(cx);
 27    let repository = project.read(cx).active_repository(cx)?;
 28    let repository_snapshot = repository.read(cx).snapshot();
 29    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 30    let root_name = worktree.read(cx).root_name_str().to_owned();
 31    let cursor_path: Arc<Path> = file.path().as_std_path().into();
 32    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 33        return None;
 34    }
 35
 36    let repository_url = repository_snapshot
 37        .remote_origin_url
 38        .clone()
 39        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 40    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 41
 42    let git_store = project.read(cx).git_store().clone();
 43
 44    Some(cx.spawn(async move |mut cx| {
 45        let snapshots_by_path =
 46            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 47
 48        events.retain(|stored_event| {
 49            let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
 50            let relative_path = strip_root_name(path, &root_name);
 51            snapshots_by_path.contains_key(relative_path)
 52        });
 53
 54        let line_comment_prefix = snapshot
 55            .language()
 56            .and_then(|lang| lang.config().line_comments.first())
 57            .map(|s| s.to_string())
 58            .unwrap_or_default();
 59        let (cursor_excerpt, cursor_offset) = cx
 60            .background_executor()
 61            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 62            .await;
 63        let uncommitted_diff = cx
 64            .background_executor()
 65            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 66            .await;
 67
 68        let mut edit_history = String::new();
 69        for stored_event in &events {
 70            write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
 71            if !edit_history.ends_with('\n') {
 72                edit_history.push('\n');
 73            }
 74        }
 75
 76        let mut spec = ExampleSpec {
 77            name: generate_timestamp_name(),
 78            repository_url,
 79            revision,
 80            tags: Vec::new(),
 81            reasoning: None,
 82            uncommitted_diff,
 83            cursor_path,
 84            cursor_position: String::new(),
 85            edit_history,
 86            expected_patches: Vec::new(),
 87        };
 88        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
 89        Ok(spec)
 90    }))
 91}
 92
 93fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
 94    path.strip_prefix(root_name).unwrap_or(path)
 95}
 96
 97fn write_event_with_relative_paths(
 98    output: &mut String,
 99    event: &zeta_prompt::Event,
100    root_name: &str,
101) {
102    fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
103        for component in strip_root_name(path, root_name).components() {
104            output.push('/');
105            write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
106        }
107    }
108
109    let zeta_prompt::Event::BufferChange {
110        path,
111        old_path,
112        diff,
113        ..
114    } = event;
115
116    output.push_str("--- a");
117    write_relative_path(output, old_path.as_ref(), root_name);
118    output.push_str("\n+++ b");
119    write_relative_path(output, path.as_ref(), root_name);
120    output.push('\n');
121    output.push_str(diff);
122}
123
124fn compute_cursor_excerpt(
125    snapshot: &language::BufferSnapshot,
126    cursor_anchor: language::Anchor,
127) -> (String, usize) {
128    use text::ToOffset as _;
129
130    let cursor_point = cursor_anchor.to_point(snapshot);
131    let (_editable_range, context_range) =
132        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
133    let context_start_offset = context_range.start.to_offset(snapshot);
134    let cursor_offset = cursor_anchor.to_offset(snapshot);
135    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
136    let excerpt = snapshot.text_for_range(context_range).collect::<String>();
137    (excerpt, cursor_offset_in_excerpt)
138}
139
140async fn collect_snapshots(
141    project: &Entity<Project>,
142    git_store: &Entity<project::git_store::GitStore>,
143    worktree_id: WorktreeId,
144    events: &[StoredEvent],
145    cx: &mut gpui::AsyncApp,
146) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
147    let mut snapshots_by_path = HashMap::default();
148    for stored_event in events {
149        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
150        if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
151            let project_path = project
152                .find_project_path(path, cx)
153                .filter(|path| path.worktree_id == worktree_id)?;
154            let relative_path: Arc<Path> = project_path.path.as_std_path().into();
155            Some((project_path, relative_path))
156        }) {
157            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
158                let buffer = project
159                    .update(cx, |project, cx| {
160                        project.open_buffer(project_path.clone(), cx)
161                    })
162                    .await?;
163                let diff = git_store
164                    .update(cx, |git_store, cx| {
165                        git_store.open_uncommitted_diff(buffer.clone(), cx)
166                    })
167                    .await?;
168                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
169                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
170            }
171        }
172    }
173    Ok(snapshots_by_path)
174}
175
176fn compute_uncommitted_diff(
177    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
178) -> String {
179    let mut uncommitted_diff = String::new();
180    for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
181        if let Some(head_text) = &diff_snapshot.base_text_string() {
182            let file_diff = language::unified_diff(head_text, &before_text.text());
183            if !file_diff.is_empty() {
184                let path_str = relative_path.to_string_lossy();
185                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
186                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
187                uncommitted_diff.push_str(&file_diff);
188                if !uncommitted_diff.ends_with('\n') {
189                    uncommitted_diff.push('\n');
190                }
191            }
192        }
193    }
194    uncommitted_diff
195}
196
197fn generate_timestamp_name() -> String {
198    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
199    match format {
200        Ok(format) => {
201            let now = time::OffsetDateTime::now_local()
202                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
203            now.format(&format)
204                .unwrap_or_else(|_| "unknown-time".to_string())
205        }
206        Err(_) => "unknown-time".to_string(),
207    }
208}
209
210pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
211    let capture_rate = language::language_settings::all_language_settings(None, cx)
212        .edit_predictions
213        .example_capture_rate
214        .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
215    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
216        && rand::random::<u16>() % 10_000 < capture_rate
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`:
    /// - the uncommitted diff reflects HEAD vs. the on-disk contents (i.e. the
    ///   state before the in-editor edits),
    /// - the edit history contains the in-editor edits to the project file,
    /// - edits to a file outside the project's worktree are left out.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let git_dir = Path::new("/project/.git");

        // Contents of src/main.rs at HEAD.
        let head_text = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents of src/main.rs on disk: HEAD plus two uncommitted comments.
        let on_disk_text = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": on_disk_text,
                }
            }),
        )
        .await;

        // A file that lives outside the main project.
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_text.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let prediction_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        prediction_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
        cx.run_until_parked();

        buffer.update(cx, |buffer, cx| {
            // Two separate insertions, applied in this order.
            for (row, inserted) in [(6, "    // comment 3\n"), (4, "    // comment 4\n")] {
                let point = Point::new(row, 0);
                buffer.edit([(point..point, inserted)], None, cx);
            }

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit the external file (outside the main project's worktree).
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        prediction_store.update(cx, |store, cx| {
            store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let start = Point::new(0, 0);
            buffer.edit([(start..start, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // The external edit must have been recorded as the most recent event.
        let events = prediction_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        let last_event = events.last().unwrap().event.as_ref();
        assert!(
            matches!(
                last_event,
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        let mut example = cx
            .update(|cx| {
                capture_example(project.clone(), buffer.clone(), Anchor::MIN, events, cx).unwrap()
            })
            .await
            .unwrap();
        // The generated name is time-based; pin it so the comparison is stable.
        example.name = "test".to_string();

        let expected = ExampleSpec {
            name: "test".to_string(),
            repository_url: "https://github.com/test/repo.git".to_string(),
            revision: "abc123def456".to_string(),
            tags: Vec::new(),
            reasoning: None,
            uncommitted_diff: indoc! {"
                --- a/src/main.rs
                +++ b/src/main.rs
                @@ -1,4 +1,5 @@
                 fn main() {
                +    // comment 1
                     one();
                     two();
                     three();
                @@ -7,5 +8,6 @@
                     six();
                     seven();
                     eight();
                +    // comment 2
                     nine();
                 }
            "}
            .to_string(),
            cursor_path: Path::new("src/main.rs").into(),
            cursor_position: indoc! {"
                fn main() {
                ^[CURSOR_POSITION]
                    // comment 1
                    one();
                    two();
                    // comment 4
                    three();
                    four();
                    // comment 3
                    five();
                    six();
                    seven();
                    eight();
                    // comment 2
                    nine();
                }
            "}
            .to_string(),
            edit_history: indoc! {"
                --- a/src/main.rs
                +++ b/src/main.rs
                @@ -2,8 +2,10 @@
                     // comment 1
                     one();
                     two();
                +    // comment 4
                     three();
                     four();
                +    // comment 3
                     five();
                     six();
                     seven();
            "}
            .to_string(),
            expected_patches: Vec::new(),
        };

        pretty_assertions::assert_eq!(example, expected);
    }

    /// Installs the globals the test relies on: a test settings store, test
    /// logging, a client backed by a fake HTTP client that answers 404, the
    /// language-model registry, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let clock = Arc::new(FakeSystemClock::new());
            let client = Client::new(clock, FakeHttpClient::with_404_response(), cx);
            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}
463}