capture_example.rs

  1use crate::{
  2    EditPredictionExampleCaptureFeatureFlag, EditPredictionStore, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use feature_flags::FeatureFlagAppExt as _;
  9use gpui::{App, Entity, Task};
 10use language::{Buffer, ToPoint as _};
 11use project::{Project, WorktreeId};
 12use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 13use text::BufferSnapshot as TextBufferSnapshot;
 14
/// Fallback capture rate, expressed as captures per 10,000 edit predictions.
///
/// Used by `should_sample_edit_prediction_example_capture` when the
/// `example_capture_rate` edit-prediction setting is not set.
 15pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
 16
 17pub fn capture_example(
 18    project: Entity<Project>,
 19    buffer: Entity<Buffer>,
 20    cursor_anchor: language::Anchor,
 21    cx: &mut App,
 22) -> Option<Task<Result<ExampleSpec>>> {
 23    let ep_store = EditPredictionStore::try_global(cx)?;
 24    let snapshot = buffer.read(cx).snapshot();
 25    let file = snapshot.file()?;
 26    let worktree_id = file.worktree_id(cx);
 27    let repository = project.read(cx).active_repository(cx)?;
 28    let repository_snapshot = repository.read(cx).snapshot();
 29    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 30    let cursor_path = worktree.read(cx).root_name().join(file.path());
 31    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 32        return None;
 33    }
 34
 35    let repository_url = repository_snapshot
 36        .remote_origin_url
 37        .clone()
 38        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 39    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 40
 41    let mut events = ep_store.update(cx, |store, cx| {
 42        store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 43    });
 44
 45    let git_store = project.read(cx).git_store().clone();
 46
 47    Some(cx.spawn(async move |mut cx| {
 48        let snapshots_by_path =
 49            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 50
 51        events.retain(|stored_event| {
 52            match stored_event.event.as_ref() {
 53                zeta_prompt::Event::BufferChange { path, .. } => {
 54                    if !snapshots_by_path.contains_key(path) {
 55                        return false;
 56                    }
 57                }
 58            }
 59            true
 60        });
 61
 62        let line_comment_prefix = snapshot
 63            .language()
 64            .and_then(|lang| lang.config().line_comments.first())
 65            .map(|s| s.to_string())
 66            .unwrap_or_default();
 67        let (cursor_excerpt, cursor_offset) = cx
 68            .background_executor()
 69            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 70            .await;
 71        let uncommitted_diff = cx
 72            .background_executor()
 73            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 74            .await;
 75
 76        let mut edit_history = String::new();
 77        for stored_event in &events {
 78            zeta_prompt::write_event(&mut edit_history, &stored_event.event);
 79            if !edit_history.ends_with('\n') {
 80                edit_history.push('\n');
 81            }
 82        }
 83
 84        let mut spec = ExampleSpec {
 85            name: generate_timestamp_name(),
 86            repository_url,
 87            revision,
 88            tags: Vec::new(),
 89            reasoning: None,
 90            uncommitted_diff,
 91            cursor_path: cursor_path.as_std_path().into(),
 92            cursor_position: String::new(),
 93            edit_history,
 94            expected_patches: Vec::new(),
 95        };
 96        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
 97        Ok(spec)
 98    }))
 99}
100
101fn compute_cursor_excerpt(
102    snapshot: &language::BufferSnapshot,
103    cursor_anchor: language::Anchor,
104) -> (String, usize) {
105    use text::ToOffset as _;
106
107    let cursor_point = cursor_anchor.to_point(snapshot);
108    let (_editable_range, context_range) =
109        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
110    let context_start_offset = context_range.start.to_offset(snapshot);
111    let cursor_offset = cursor_anchor.to_offset(snapshot);
112    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
113    let excerpt = snapshot.text_for_range(context_range).collect::<String>();
114    (excerpt, cursor_offset_in_excerpt)
115}
116
117async fn collect_snapshots(
118    project: &Entity<Project>,
119    git_store: &Entity<project::git_store::GitStore>,
120    worktree_id: WorktreeId,
121    events: &[StoredEvent],
122    cx: &mut gpui::AsyncApp,
123) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
124    let mut snapshots_by_path = HashMap::default();
125    let root_name = project.read_with(cx, |project, cx| {
126        project
127            .worktree_for_id(worktree_id, cx)
128            .unwrap()
129            .read(cx)
130            .root_name()
131            .to_owned()
132    })?;
133    for stored_event in events {
134        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
135        if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
136            let project_path = project
137                .find_project_path(path, cx)
138                .filter(|path| path.worktree_id == worktree_id)?;
139            let full_path = root_name.join(&project_path.path).as_std_path().into();
140            Some((project_path, full_path))
141        })? {
142            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
143                let buffer = project
144                    .update(cx, |project, cx| {
145                        project.open_buffer(project_path.clone(), cx)
146                    })?
147                    .await?;
148                let diff = git_store
149                    .update(cx, |git_store, cx| {
150                        git_store.open_uncommitted_diff(buffer.clone(), cx)
151                    })?
152                    .await?;
153                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
154                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
155            }
156        }
157    }
158    Ok(snapshots_by_path)
159}
160
161fn compute_uncommitted_diff(
162    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
163) -> String {
164    let mut uncommitted_diff = String::new();
165    for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
166        if let Some(head_text) = &diff_snapshot.base_text_string() {
167            let file_diff = language::unified_diff(head_text, &before_text.text());
168            if !file_diff.is_empty() {
169                let path_str = full_path.to_string_lossy();
170                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
171                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
172                uncommitted_diff.push_str(&file_diff);
173                if !uncommitted_diff.ends_with('\n') {
174                    uncommitted_diff.push('\n');
175                }
176            }
177        }
178    }
179    uncommitted_diff
180}
181
182fn generate_timestamp_name() -> String {
183    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
184    match format {
185        Ok(format) => {
186            let now = time::OffsetDateTime::now_local()
187                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
188            now.format(&format)
189                .unwrap_or_else(|_| "unknown-time".to_string())
190        }
191        Err(_) => "unknown-time".to_string(),
192    }
193}
194
195pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
196    let capture_rate = language::language_settings::all_language_settings(None, cx)
197        .edit_predictions
198        .example_capture_rate
199        .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
200    cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
201        && rand::random::<u16>() % 10_000 < capture_rate
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`: changes already on disk relative to HEAD
    /// must show up in `uncommitted_diff`, while edits made after the buffer was
    /// registered with the store must show up in `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // File contents as committed at HEAD.
        let head_text = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // File contents on disk: HEAD plus two uncommitted comments.
        let working_copy_text = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": working_copy_text,
                }
            }),
        )
        .await;

        let git_dir = Path::new("/project/.git");
        fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_text.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let open_buffer = project.update(cx, |project, cx| {
            project.open_local_buffer("/project/src/main.rs", cx)
        });
        let buffer = open_buffer.await.unwrap();

        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
        cx.run_until_parked();

        // Two in-session edits; these are what the captured edit history should contain.
        buffer.update(cx, |buffer, cx| {
            let third_comment_at = Point::new(6, 0);
            buffer.edit(
                [(third_comment_at..third_comment_at, "    // comment 3\n")],
                None,
                cx,
            );
            let fourth_comment_at = Point::new(4, 0);
            buffer.edit(
                [(fourth_comment_at..fourth_comment_at, "    // comment 4\n")],
                None,
                cx,
            );

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        let capture_task = cx.update(|cx| {
            capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap()
        });
        let mut example = capture_task.await.unwrap();
        // The generated name is a timestamp; pin it so the comparison is deterministic.
        example.name = "test".to_string();

        let expected_uncommitted_diff = indoc! {"
            --- a/project/src/main.rs
            +++ b/project/src/main.rs
            @@ -1,4 +1,5 @@
             fn main() {
            +    // comment 1
                 one();
                 two();
                 three();
            @@ -7,5 +8,6 @@
                 six();
                 seven();
                 eight();
            +    // comment 2
                 nine();
             }
        "};

        let expected_cursor_position = indoc! {"
            fn main() {
            ^[CURSOR_POSITION]
                // comment 1
                one();
                two();
                // comment 4
                three();
                four();
                // comment 3
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        let expected_edit_history = indoc! {"
            --- a/project/src/main.rs
            +++ b/project/src/main.rs
            @@ -2,8 +2,10 @@
                 // comment 1
                 one();
                 two();
            +    // comment 4
                 three();
                 four();
            +    // comment 3
                 five();
                 six();
                 seven();
        "};

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                uncommitted_diff: expected_uncommitted_diff.to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: expected_cursor_position.to_string(),
                edit_history: expected_edit_history.to_string(),
                expected_patches: Vec::new()
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store, test logging,
    /// the language-model subsystem, and the global `EditPredictionStore` backed by a
    /// client whose HTTP layer answers every request with 404.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let test_settings = SettingsStore::test(cx);
            cx.set_global(test_settings);
            zlog::init_test();
            let clock = Arc::new(FakeSystemClock::new());
            let client = Client::new(clock, FakeHttpClient::with_404_response(), cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}