capture_example.rs

  1use crate::{
  2    EditPredictionStore, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use gpui::{App, Entity, Task};
  9use language::{Buffer, ToPoint as _};
 10use project::Project;
 11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 12use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
 13
 14pub fn capture_example(
 15    project: Entity<Project>,
 16    buffer: Entity<Buffer>,
 17    cursor_anchor: language::Anchor,
 18    last_event_is_expected_patch: bool,
 19    cx: &mut App,
 20) -> Option<Task<Result<ExampleSpec>>> {
 21    let ep_store = EditPredictionStore::try_global(cx)?;
 22    let snapshot = buffer.read(cx).snapshot();
 23    let file = snapshot.file()?;
 24    let worktree_id = file.worktree_id(cx);
 25    let repository = project.read(cx).active_repository(cx)?;
 26    let repository_snapshot = repository.read(cx).snapshot();
 27    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 28    let cursor_path = worktree.read(cx).root_name().join(file.path());
 29    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 30        return None;
 31    }
 32
 33    let repository_url = repository_snapshot
 34        .remote_origin_url
 35        .clone()
 36        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 37    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 38
 39    let mut events = ep_store.update(cx, |store, cx| {
 40        store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 41    });
 42
 43    let git_store = project.read(cx).git_store().clone();
 44
 45    Some(cx.spawn(async move |mut cx| {
 46        let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
 47        let cursor_excerpt = cx
 48            .background_executor()
 49            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 50            .await;
 51        let uncommitted_diff = cx
 52            .background_executor()
 53            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 54            .await;
 55
 56        let mut edit_history = String::new();
 57        let mut expected_patch = String::new();
 58        if last_event_is_expected_patch {
 59            if let Some(stored_event) = events.pop() {
 60                zeta_prompt::write_event(&mut expected_patch, &stored_event.event);
 61            }
 62        }
 63
 64        for stored_event in &events {
 65            zeta_prompt::write_event(&mut edit_history, &stored_event.event);
 66            if !edit_history.ends_with('\n') {
 67                edit_history.push('\n');
 68            }
 69        }
 70
 71        let name = generate_timestamp_name();
 72
 73        Ok(ExampleSpec {
 74            name,
 75            repository_url,
 76            revision,
 77            uncommitted_diff,
 78            cursor_path: cursor_path.as_std_path().into(),
 79            cursor_position: cursor_excerpt,
 80            edit_history,
 81            expected_patch,
 82        })
 83    }))
 84}
 85
 86fn compute_cursor_excerpt(
 87    snapshot: &language::BufferSnapshot,
 88    cursor_anchor: language::Anchor,
 89) -> String {
 90    let cursor_point = cursor_anchor.to_point(snapshot);
 91    let (_editable_range, context_range) =
 92        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
 93
 94    let context_start_offset = context_range.start.to_offset(snapshot);
 95    let cursor_offset = cursor_anchor.to_offset(snapshot);
 96    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
 97    let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
 98    if cursor_offset_in_excerpt <= excerpt.len() {
 99        excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
100    }
101    excerpt
102}
103
104async fn collect_snapshots(
105    project: &Entity<Project>,
106    git_store: &Entity<project::git_store::GitStore>,
107    events: &[StoredEvent],
108    cx: &mut gpui::AsyncApp,
109) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
110    let mut snapshots_by_path = HashMap::default();
111    for stored_event in events {
112        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
113        if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
114            let project_path = project.find_project_path(path, cx)?;
115            let full_path = project
116                .worktree_for_id(project_path.worktree_id, cx)?
117                .read(cx)
118                .root_name()
119                .join(&project_path.path)
120                .as_std_path()
121                .into();
122            Some((project_path, full_path))
123        })? {
124            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
125                let buffer = project
126                    .update(cx, |project, cx| {
127                        project.open_buffer(project_path.clone(), cx)
128                    })?
129                    .await?;
130                let diff = git_store
131                    .update(cx, |git_store, cx| {
132                        git_store.open_uncommitted_diff(buffer.clone(), cx)
133                    })?
134                    .await?;
135                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
136                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
137            }
138        }
139    }
140    Ok(snapshots_by_path)
141}
142
143fn compute_uncommitted_diff(
144    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
145) -> String {
146    let mut uncommitted_diff = String::new();
147    for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
148        if let Some(head_text) = &diff_snapshot.base_text_string() {
149            let file_diff = language::unified_diff(head_text, &before_text.text());
150            if !file_diff.is_empty() {
151                let path_str = full_path.to_string_lossy();
152                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
153                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
154                uncommitted_diff.push_str(&file_diff);
155                if !uncommitted_diff.ends_with('\n') {
156                    uncommitted_diff.push('\n');
157                }
158            }
159        }
160    }
161    uncommitted_diff
162}
163
164fn generate_timestamp_name() -> String {
165    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
166    match format {
167        Ok(format) => {
168            let now = time::OffsetDateTime::now_local()
169                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
170            now.format(&format)
171                .unwrap_or_else(|_| "unknown-time".to_string())
172        }
173        Err(_) => "unknown-time".to_string(),
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`. A repository whose working copy already differs
    /// from HEAD receives two unsaved buffer edits, and every field of the captured spec is
    /// verified.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fake_fs = FakeFs::new(cx.executor());
        let git_dir = Path::new("/project/.git");

        // Contents recorded in the fake HEAD commit.
        let head_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus "comment 1" and "comment 2". These pre-existing changes
        // are expected in `uncommitted_diff`, not in `edit_history`.
        let on_disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fake_fs
            .insert_tree(
                "/project",
                json!({
                    ".git": {},
                    "src": {
                        "main.rs": on_disk_contents,
                    }
                }),
            )
            .await;
        fake_fs.set_head_for_repo(
            git_dir,
            &[("src/main.rs", head_contents.to_string())],
            "abc123def456",
        );
        fake_fs.set_remote_for_repo(git_dir, "origin", "https://github.com/test/repo.git");

        let project = Project::test(fake_fs.clone(), ["/project".as_ref()], cx).await;
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the store so its edits can show up in the captured history.
        let store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
        cx.run_until_parked();

        // Two unsaved edits; the second is inserted above the first.
        buffer.update(cx, |buffer, cx| {
            for (row, comment) in [(6, "    // comment 3\n"), (4, "    // comment 4\n")] {
                let at = Point::new(row, 0);
                buffer.edit([(at..at, comment)], None, cx);
            }

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        let mut captured = cx
            .update(|cx| {
                capture_example(project.clone(), buffer.clone(), Anchor::MIN, false, cx).unwrap()
            })
            .await
            .unwrap();
        // `name` is derived from the wall clock; pin it so the comparison is deterministic.
        captured.name = "test".to_string();

        pretty_assertions::assert_eq!(
            captured,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: indoc! {"
                    <|user_cursor|>fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patch: "".to_string(),
            }
        );
    }

    /// Installs test settings and logging, then creates the global `EditPredictionStore`
    /// backed by a fake HTTP client and a fake system clock.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let http = FakeHttpClient::with_404_response();
            let clock = Arc::new(FakeSystemClock::new());
            let client = Client::new(clock, http, cx);
            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}