capture_example.rs

  1use crate::{
  2    EditPredictionStore, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use gpui::{App, Entity, Task};
  9use language::{Buffer, ToPoint as _};
 10use project::{Project, WorktreeId};
 11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 12use text::BufferSnapshot as TextBufferSnapshot;
 13
 14pub fn capture_example(
 15    project: Entity<Project>,
 16    buffer: Entity<Buffer>,
 17    cursor_anchor: language::Anchor,
 18    cx: &mut App,
 19) -> Option<Task<Result<ExampleSpec>>> {
 20    let ep_store = EditPredictionStore::try_global(cx)?;
 21    let snapshot = buffer.read(cx).snapshot();
 22    let file = snapshot.file()?;
 23    let worktree_id = file.worktree_id(cx);
 24    let repository = project.read(cx).active_repository(cx)?;
 25    let repository_snapshot = repository.read(cx).snapshot();
 26    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 27    let cursor_path = worktree.read(cx).root_name().join(file.path());
 28    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 29        return None;
 30    }
 31
 32    let repository_url = repository_snapshot
 33        .remote_origin_url
 34        .clone()
 35        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 36    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 37
 38    let mut events = ep_store.update(cx, |store, cx| {
 39        store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 40    });
 41
 42    let git_store = project.read(cx).git_store().clone();
 43
 44    Some(cx.spawn(async move |mut cx| {
 45        let snapshots_by_path =
 46            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
 47
 48        events.retain(|stored_event| {
 49            match stored_event.event.as_ref() {
 50                zeta_prompt::Event::BufferChange { path, .. } => {
 51                    if !snapshots_by_path.contains_key(path) {
 52                        return false;
 53                    }
 54                }
 55            }
 56            true
 57        });
 58
 59        let line_comment_prefix = snapshot
 60            .language()
 61            .and_then(|lang| lang.config().line_comments.first())
 62            .map(|s| s.to_string())
 63            .unwrap_or_default();
 64        let (cursor_excerpt, cursor_offset) = cx
 65            .background_executor()
 66            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 67            .await;
 68        let uncommitted_diff = cx
 69            .background_executor()
 70            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 71            .await;
 72
 73        let mut edit_history = String::new();
 74        for stored_event in &events {
 75            zeta_prompt::write_event(&mut edit_history, &stored_event.event);
 76            if !edit_history.ends_with('\n') {
 77                edit_history.push('\n');
 78            }
 79        }
 80
 81        let mut spec = ExampleSpec {
 82            name: generate_timestamp_name(),
 83            repository_url,
 84            revision,
 85            tags: Vec::new(),
 86            reasoning: None,
 87            uncommitted_diff,
 88            cursor_path: cursor_path.as_std_path().into(),
 89            cursor_position: String::new(),
 90            edit_history,
 91            expected_patches: Vec::new(),
 92        };
 93        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
 94        Ok(spec)
 95    }))
 96}
 97
 98fn compute_cursor_excerpt(
 99    snapshot: &language::BufferSnapshot,
100    cursor_anchor: language::Anchor,
101) -> (String, usize) {
102    use text::ToOffset as _;
103
104    let cursor_point = cursor_anchor.to_point(snapshot);
105    let (_editable_range, context_range) =
106        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
107    let context_start_offset = context_range.start.to_offset(snapshot);
108    let cursor_offset = cursor_anchor.to_offset(snapshot);
109    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
110    let excerpt = snapshot.text_for_range(context_range).collect::<String>();
111    (excerpt, cursor_offset_in_excerpt)
112}
113
114async fn collect_snapshots(
115    project: &Entity<Project>,
116    git_store: &Entity<project::git_store::GitStore>,
117    worktree_id: WorktreeId,
118    events: &[StoredEvent],
119    cx: &mut gpui::AsyncApp,
120) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
121    let mut snapshots_by_path = HashMap::default();
122    let root_name = project.read_with(cx, |project, cx| {
123        project
124            .worktree_for_id(worktree_id, cx)
125            .unwrap()
126            .read(cx)
127            .root_name()
128            .to_owned()
129    })?;
130    for stored_event in events {
131        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
132        if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
133            let project_path = project
134                .find_project_path(path, cx)
135                .filter(|path| path.worktree_id == worktree_id)?;
136            let full_path = root_name.join(&project_path.path).as_std_path().into();
137            Some((project_path, full_path))
138        })? {
139            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
140                let buffer = project
141                    .update(cx, |project, cx| {
142                        project.open_buffer(project_path.clone(), cx)
143                    })?
144                    .await?;
145                let diff = git_store
146                    .update(cx, |git_store, cx| {
147                        git_store.open_uncommitted_diff(buffer.clone(), cx)
148                    })?
149                    .await?;
150                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
151                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
152            }
153        }
154    }
155    Ok(snapshots_by_path)
156}
157
158fn compute_uncommitted_diff(
159    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
160) -> String {
161    let mut uncommitted_diff = String::new();
162    for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
163        if let Some(head_text) = &diff_snapshot.base_text_string() {
164            let file_diff = language::unified_diff(head_text, &before_text.text());
165            if !file_diff.is_empty() {
166                let path_str = full_path.to_string_lossy();
167                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
168                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
169                uncommitted_diff.push_str(&file_diff);
170                if !uncommitted_diff.ends_with('\n') {
171                    uncommitted_diff.push('\n');
172                }
173            }
174        }
175    }
176    uncommitted_diff
177}
178
179fn generate_timestamp_name() -> String {
180    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
181    match format {
182        Ok(format) => {
183            let now = time::OffsetDateTime::now_local()
184                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
185            now.format(&format)
186                .unwrap_or_else(|_| "unknown-time".to_string())
187        }
188        Err(_) => "unknown-time".to_string(),
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` on a single-file repository.
    ///
    /// Three versions of `src/main.rs` are involved:
    /// - HEAD (`committed_contents`);
    /// - the file on disk (`disk_contents`: HEAD plus "comment 1" and "comment 2");
    /// - the open buffer after two unsaved edits ("comment 3" and "comment 4").
    ///
    /// In the expected spec, the uncommitted diff contains only the HEAD -> disk changes.
    /// The edit history contains only the two buffer edits.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents recorded as the HEAD commit below.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Give the fake repository a HEAD commit and an "origin" remote so that
        // `capture_example` can fill in `revision` and `repository_url`.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the edit prediction store so that the edits below are
        // recorded; they show up in the expected `edit_history`.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Insert "comment 3", then "comment 4", into the buffer only (nothing is saved).
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Capture with the cursor at the very start of the buffer.
        let mut example = cx
            .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
            .await
            .unwrap();
        // The generated name is a wall-clock timestamp; pin it so the comparison is stable.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. the buffer text before the recorded edits, i.e. only the
                // on-disk changes ("comment 1" and "comment 2").
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the buffer edits recorded after registration ("comment 3" and "comment 4").
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: Vec::new()
            }
        );
    }

    /// Installs the globals `capture_example` depends on:
    /// - a test `SettingsStore`;
    /// - `zlog` test initialization;
    /// - `language_model`, initialized with a client backed by a fake HTTP client that
    ///   answers 404;
    /// - the global `EditPredictionStore`, which `capture_example` fetches via `try_global`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}