capture_example.rs

  1use crate::{
  2    EditPredictionStore, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use gpui::{App, Entity, Task};
  9use language::{Buffer, ToPoint as _};
 10use project::Project;
 11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 12use text::BufferSnapshot as TextBufferSnapshot;
 13
 14pub fn capture_example(
 15    project: Entity<Project>,
 16    buffer: Entity<Buffer>,
 17    cursor_anchor: language::Anchor,
 18    cx: &mut App,
 19) -> Option<Task<Result<ExampleSpec>>> {
 20    let ep_store = EditPredictionStore::try_global(cx)?;
 21    let snapshot = buffer.read(cx).snapshot();
 22    let file = snapshot.file()?;
 23    let worktree_id = file.worktree_id(cx);
 24    let repository = project.read(cx).active_repository(cx)?;
 25    let repository_snapshot = repository.read(cx).snapshot();
 26    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 27    let cursor_path = worktree.read(cx).root_name().join(file.path());
 28    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 29        return None;
 30    }
 31
 32    let repository_url = repository_snapshot
 33        .remote_origin_url
 34        .clone()
 35        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 36    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 37
 38    let events = ep_store.update(cx, |store, cx| {
 39        store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 40    });
 41
 42    let git_store = project.read(cx).git_store().clone();
 43
 44    Some(cx.spawn(async move |mut cx| {
 45        let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
 46
 47        let line_comment_prefix = snapshot
 48            .language()
 49            .and_then(|lang| lang.config().line_comments.first())
 50            .map(|s| s.to_string())
 51            .unwrap_or_default();
 52        let (cursor_excerpt, cursor_offset) = cx
 53            .background_executor()
 54            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 55            .await;
 56        let uncommitted_diff = cx
 57            .background_executor()
 58            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 59            .await;
 60
 61        let mut edit_history = String::new();
 62        for stored_event in &events {
 63            zeta_prompt::write_event(&mut edit_history, &stored_event.event);
 64            if !edit_history.ends_with('\n') {
 65                edit_history.push('\n');
 66            }
 67        }
 68
 69        let mut spec = ExampleSpec {
 70            name: generate_timestamp_name(),
 71            repository_url,
 72            revision,
 73            uncommitted_diff,
 74            cursor_path: cursor_path.as_std_path().into(),
 75            cursor_position: String::new(),
 76            edit_history,
 77            expected_patch: String::new(),
 78        };
 79        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
 80        Ok(spec)
 81    }))
 82}
 83
 84fn compute_cursor_excerpt(
 85    snapshot: &language::BufferSnapshot,
 86    cursor_anchor: language::Anchor,
 87) -> (String, usize) {
 88    use text::ToOffset as _;
 89
 90    let cursor_point = cursor_anchor.to_point(snapshot);
 91    let (_editable_range, context_range) =
 92        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
 93    let context_start_offset = context_range.start.to_offset(snapshot);
 94    let cursor_offset = cursor_anchor.to_offset(snapshot);
 95    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
 96    let excerpt = snapshot.text_for_range(context_range).collect::<String>();
 97    (excerpt, cursor_offset_in_excerpt)
 98}
 99
100async fn collect_snapshots(
101    project: &Entity<Project>,
102    git_store: &Entity<project::git_store::GitStore>,
103    events: &[StoredEvent],
104    cx: &mut gpui::AsyncApp,
105) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
106    let mut snapshots_by_path = HashMap::default();
107    for stored_event in events {
108        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
109        if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
110            let project_path = project.find_project_path(path, cx)?;
111            let full_path = project
112                .worktree_for_id(project_path.worktree_id, cx)?
113                .read(cx)
114                .root_name()
115                .join(&project_path.path)
116                .as_std_path()
117                .into();
118            Some((project_path, full_path))
119        })? {
120            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
121                let buffer = project
122                    .update(cx, |project, cx| {
123                        project.open_buffer(project_path.clone(), cx)
124                    })?
125                    .await?;
126                let diff = git_store
127                    .update(cx, |git_store, cx| {
128                        git_store.open_uncommitted_diff(buffer.clone(), cx)
129                    })?
130                    .await?;
131                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
132                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
133            }
134        }
135    }
136    Ok(snapshots_by_path)
137}
138
139fn compute_uncommitted_diff(
140    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
141) -> String {
142    let mut uncommitted_diff = String::new();
143    for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
144        if let Some(head_text) = &diff_snapshot.base_text_string() {
145            let file_diff = language::unified_diff(head_text, &before_text.text());
146            if !file_diff.is_empty() {
147                let path_str = full_path.to_string_lossy();
148                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
149                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
150                uncommitted_diff.push_str(&file_diff);
151                if !uncommitted_diff.ends_with('\n') {
152                    uncommitted_diff.push('\n');
153                }
154            }
155        }
156    }
157    uncommitted_diff
158}
159
160fn generate_timestamp_name() -> String {
161    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
162    match format {
163        Ok(format) => {
164            let now = time::OffsetDateTime::now_local()
165                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
166            now.format(&format)
167                .unwrap_or_else(|_| "unknown-time".to_string())
168        }
169        Err(_) => "unknown-time".to_string(),
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` on a single-file repository.
    ///
    /// HEAD holds `head_contents`. The on-disk working copy adds `comment 1` and
    /// `comment 2`, and `comment 3` / `comment 4` are then typed into the open buffer.
    /// The on-disk additions must appear in `uncommitted_diff`, and the in-buffer edits
    /// must appear in `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fake_fs = FakeFs::new(cx.executor());

        let head_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        let working_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fake_fs
            .insert_tree(
                "/project",
                json!({
                    ".git": {},
                    "src": {
                        "main.rs": working_contents,
                    }
                }),
            )
            .await;

        let git_dir = Path::new("/project/.git");
        fake_fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_contents.to_string())],
            "abc123def456",
        );
        fake_fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fake_fs.clone(), ["/project".as_ref()], cx).await;

        let main_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        let store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        store.update(cx, |store, cx| store.register_buffer(&main_buffer, &project, cx));
        cx.run_until_parked();

        main_buffer.update(cx, |buffer, cx| {
            // Edit the later row first so the earlier insertion cannot shift its position.
            for (row, comment) in [(6, "    // comment 3\n"), (4, "    // comment 4\n")] {
                let at = Point::new(row, 0);
                buffer.edit([(at..at, comment)], None, cx);
            }

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        let mut captured = cx
            .update(|cx| {
                capture_example(project.clone(), main_buffer.clone(), Anchor::MIN, cx).unwrap()
            })
            .await
            .unwrap();
        // The generated name is a timestamp; pin it so the comparison is deterministic.
        captured.name = "test".to_string();

        pretty_assertions::assert_eq!(
            captured,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patch: "".to_string(),
            }
        );
    }

    /// Installs the globals the test relies on: a test `SettingsStore`, test logging,
    /// `language_model` initialized with a fake-HTTP client, and the global
    /// `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings = SettingsStore::test(cx);
            cx.set_global(settings);
            zlog::init_test();

            let client = Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            );
            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}