capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 13use text::BufferSnapshot as TextBufferSnapshot;
 14
 15pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
 16
 17pub fn capture_example(
 18    project: Entity<Project>,
 19    buffer: Entity<Buffer>,
 20    cursor_anchor: language::Anchor,
 21    mut events: Vec<StoredEvent>,
 22    cx: &mut App,
 23) -> Option<Task<Result<ExampleSpec>>> {
 24    let snapshot = buffer.read(cx).snapshot();
 25    let file = snapshot.file()?;
 26    let worktree_id = file.worktree_id(cx);
 27    let repository = project.read(cx).active_repository(cx)?;
 28    let repository_snapshot = repository.read(cx).snapshot();
 29    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 30    let cursor_path = worktree.read(cx).root_name().join(file.path());
 31    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 32        return None;
 33    }
 34
 35    let repository_url = repository_snapshot
 36        .remote_origin_url
 37        .clone()
 38        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 39    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 40
 41    let git_store = project.read(cx).git_store().clone();
 42
 43    Some(cx.spawn(async move |mut cx| {
 44        let snapshots_by_path =
 45            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 46
 47        events.retain(|stored_event| {
 48            match stored_event.event.as_ref() {
 49                zeta_prompt::Event::BufferChange { path, .. } => {
 50                    if !snapshots_by_path.contains_key(path) {
 51                        return false;
 52                    }
 53                }
 54            }
 55            true
 56        });
 57
 58        let line_comment_prefix = snapshot
 59            .language()
 60            .and_then(|lang| lang.config().line_comments.first())
 61            .map(|s| s.to_string())
 62            .unwrap_or_default();
 63        let (cursor_excerpt, cursor_offset) = cx
 64            .background_executor()
 65            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 66            .await;
 67        let uncommitted_diff = cx
 68            .background_executor()
 69            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 70            .await;
 71
 72        let mut edit_history = String::new();
 73        for stored_event in &events {
 74            zeta_prompt::write_event(&mut edit_history, &stored_event.event);
 75            if !edit_history.ends_with('\n') {
 76                edit_history.push('\n');
 77            }
 78        }
 79
 80        let mut spec = ExampleSpec {
 81            name: generate_timestamp_name(),
 82            repository_url,
 83            revision,
 84            tags: Vec::new(),
 85            reasoning: None,
 86            uncommitted_diff,
 87            cursor_path: cursor_path.as_std_path().into(),
 88            cursor_position: String::new(),
 89            edit_history,
 90            expected_patches: Vec::new(),
 91        };
 92        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
 93        Ok(spec)
 94    }))
 95}
 96
 97fn compute_cursor_excerpt(
 98    snapshot: &language::BufferSnapshot,
 99    cursor_anchor: language::Anchor,
100) -> (String, usize) {
101    use text::ToOffset as _;
102
103    let cursor_point = cursor_anchor.to_point(snapshot);
104    let (_editable_range, context_range) =
105        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
106    let context_start_offset = context_range.start.to_offset(snapshot);
107    let cursor_offset = cursor_anchor.to_offset(snapshot);
108    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
109    let excerpt = snapshot.text_for_range(context_range).collect::<String>();
110    (excerpt, cursor_offset_in_excerpt)
111}
112
113async fn collect_snapshots(
114    project: &Entity<Project>,
115    git_store: &Entity<project::git_store::GitStore>,
116    worktree_id: WorktreeId,
117    events: &[StoredEvent],
118    cx: &mut gpui::AsyncApp,
119) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
120    let mut snapshots_by_path = HashMap::default();
121    let root_name = project.read_with(cx, |project, cx| {
122        project
123            .worktree_for_id(worktree_id, cx)
124            .unwrap()
125            .read(cx)
126            .root_name()
127            .to_owned()
128    })?;
129    for stored_event in events {
130        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
131        if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
132            let project_path = project
133                .find_project_path(path, cx)
134                .filter(|path| path.worktree_id == worktree_id)?;
135            let full_path = root_name.join(&project_path.path).as_std_path().into();
136            Some((project_path, full_path))
137        })? {
138            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
139                let buffer = project
140                    .update(cx, |project, cx| {
141                        project.open_buffer(project_path.clone(), cx)
142                    })?
143                    .await?;
144                let diff = git_store
145                    .update(cx, |git_store, cx| {
146                        git_store.open_uncommitted_diff(buffer.clone(), cx)
147                    })?
148                    .await?;
149                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
150                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
151            }
152        }
153    }
154    Ok(snapshots_by_path)
155}
156
157fn compute_uncommitted_diff(
158    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
159) -> String {
160    let mut uncommitted_diff = String::new();
161    for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
162        if let Some(head_text) = &diff_snapshot.base_text_string() {
163            let file_diff = language::unified_diff(head_text, &before_text.text());
164            if !file_diff.is_empty() {
165                let path_str = full_path.to_string_lossy();
166                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
167                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
168                uncommitted_diff.push_str(&file_diff);
169                if !uncommitted_diff.ends_with('\n') {
170                    uncommitted_diff.push('\n');
171                }
172            }
173        }
174    }
175    uncommitted_diff
176}
177
178fn generate_timestamp_name() -> String {
179    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
180    match format {
181        Ok(format) => {
182            let now = time::OffsetDateTime::now_local()
183                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
184            now.format(&format)
185                .unwrap_or_else(|_| "unknown-time".to_string())
186        }
187        Err(_) => "unknown-time".to_string(),
188    }
189}
190
191pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
192    let capture_rate = language::language_settings::all_language_settings(None, cx)
193        .edit_predictions
194        .example_capture_rate
195        .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
196    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
197        && rand::random::<u16>() % 10_000 < capture_rate
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// Setup: a file with uncommitted on-disk changes (comments 1 and 2) is then edited
    /// through a registered buffer (comments 3 and 4).
    ///
    /// Expected result:
    /// - `uncommitted_diff` holds the HEAD-to-disk changes,
    /// - `edit_history` holds the buffer edits,
    /// - `cursor_position` holds the excerpt around `Anchor::MIN`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fake_fs = FakeFs::new(cx.executor());
        let dot_git = Path::new("/project/.git");

        // `src/main.rs` as committed at HEAD.
        let head_text = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // `src/main.rs` on disk: HEAD plus two uncommitted comments.
        let working_text = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fake_fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": working_text,
                }
            }),
        )
        .await;

        fake_fs.set_head_for_repo(
            dot_git,
            &[("src/main.rs", head_text.to_string())],
            "abc123def456",
        );
        fake_fs.set_remote_for_repo(dot_git, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fake_fs.clone(), ["/project".as_ref()], cx).await;

        let main_rs = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        store.update(cx, |store, cx| store.register_buffer(&main_rs, &project, cx));
        cx.run_until_parked();

        // Edits made through the buffer become the recorded edit history.
        main_rs.update(cx, |main_rs, cx| {
            let comment_3_at = Point::new(6, 0);
            main_rs.edit([(comment_3_at..comment_3_at, "    // comment 3\n")], None, cx);
            let comment_4_at = Point::new(4, 0);
            main_rs.edit([(comment_4_at..comment_4_at, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                main_rs.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        let history = store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        let mut captured = cx
            .update(|cx| {
                capture_example(project.clone(), main_rs.clone(), Anchor::MIN, history, cx)
                    .unwrap()
            })
            .await
            .unwrap();
        // The generated name depends on the wall clock; pin it for comparison.
        captured.name = "test".to_string();

        pretty_assertions::assert_eq!(
            captured,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: Vec::new()
            }
        );
    }

    /// Installs test settings, logging, a client backed by a 404-only HTTP client,
    /// the language-model subsystem, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings = SettingsStore::test(cx);
            cx.set_global(settings);
            zlog::init_test();
            let client = Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            );
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}