capture_example.rs

  1use crate::{
  2    EditPredictionStore, StoredEvent,
  3    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
  4};
  5use anyhow::Result;
  6use buffer_diff::BufferDiffSnapshot;
  7use collections::HashMap;
  8use gpui::{App, Entity, Task};
  9use language::{Buffer, ToPoint as _};
 10use project::Project;
 11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 12use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
 13
 14pub fn capture_example(
 15    project: Entity<Project>,
 16    buffer: Entity<Buffer>,
 17    cursor_anchor: language::Anchor,
 18    cx: &mut App,
 19) -> Option<Task<Result<ExampleSpec>>> {
 20    let ep_store = EditPredictionStore::try_global(cx)?;
 21    let snapshot = buffer.read(cx).snapshot();
 22    let file = snapshot.file()?;
 23    let worktree_id = file.worktree_id(cx);
 24    let repository = project.read(cx).active_repository(cx)?;
 25    let repository_snapshot = repository.read(cx).snapshot();
 26    let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
 27    let cursor_path = worktree.read(cx).root_name().join(file.path());
 28    if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
 29        return None;
 30    }
 31
 32    let repository_url = repository_snapshot
 33        .remote_origin_url
 34        .clone()
 35        .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
 36    let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 37
 38    let events = ep_store.update(cx, |store, cx| {
 39        store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 40    });
 41
 42    let git_store = project.read(cx).git_store().clone();
 43
 44    Some(cx.spawn(async move |mut cx| {
 45        let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
 46        let cursor_excerpt = cx
 47            .background_executor()
 48            .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
 49            .await;
 50        let uncommitted_diff = cx
 51            .background_executor()
 52            .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
 53            .await;
 54
 55        let mut edit_history = String::new();
 56        for stored_event in &events {
 57            zeta_prompt::write_event(&mut edit_history, &stored_event.event);
 58            if !edit_history.ends_with('\n') {
 59                edit_history.push('\n');
 60            }
 61        }
 62
 63        Ok(ExampleSpec {
 64            name: generate_timestamp_name(),
 65            repository_url,
 66            revision,
 67            uncommitted_diff,
 68            cursor_path: cursor_path.as_std_path().into(),
 69            cursor_position: cursor_excerpt,
 70            edit_history,
 71            expected_patch: String::new(),
 72        })
 73    }))
 74}
 75
 76fn compute_cursor_excerpt(
 77    snapshot: &language::BufferSnapshot,
 78    cursor_anchor: language::Anchor,
 79) -> String {
 80    let cursor_point = cursor_anchor.to_point(snapshot);
 81    let (_editable_range, context_range) =
 82        editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
 83
 84    let context_start_offset = context_range.start.to_offset(snapshot);
 85    let cursor_offset = cursor_anchor.to_offset(snapshot);
 86    let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
 87    let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
 88    if cursor_offset_in_excerpt <= excerpt.len() {
 89        excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
 90    }
 91    excerpt
 92}
 93
 94async fn collect_snapshots(
 95    project: &Entity<Project>,
 96    git_store: &Entity<project::git_store::GitStore>,
 97    events: &[StoredEvent],
 98    cx: &mut gpui::AsyncApp,
 99) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
100    let mut snapshots_by_path = HashMap::default();
101    for stored_event in events {
102        let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
103        if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
104            let project_path = project.find_project_path(path, cx)?;
105            let full_path = project
106                .worktree_for_id(project_path.worktree_id, cx)?
107                .read(cx)
108                .root_name()
109                .join(&project_path.path)
110                .as_std_path()
111                .into();
112            Some((project_path, full_path))
113        })? {
114            if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
115                let buffer = project
116                    .update(cx, |project, cx| {
117                        project.open_buffer(project_path.clone(), cx)
118                    })?
119                    .await?;
120                let diff = git_store
121                    .update(cx, |git_store, cx| {
122                        git_store.open_uncommitted_diff(buffer.clone(), cx)
123                    })?
124                    .await?;
125                let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
126                entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
127            }
128        }
129    }
130    Ok(snapshots_by_path)
131}
132
133fn compute_uncommitted_diff(
134    snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
135) -> String {
136    let mut uncommitted_diff = String::new();
137    for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
138        if let Some(head_text) = &diff_snapshot.base_text_string() {
139            let file_diff = language::unified_diff(head_text, &before_text.text());
140            if !file_diff.is_empty() {
141                let path_str = full_path.to_string_lossy();
142                writeln!(uncommitted_diff, "--- a/{path_str}").ok();
143                writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
144                uncommitted_diff.push_str(&file_diff);
145                if !uncommitted_diff.ends_with('\n') {
146                    uncommitted_diff.push('\n');
147                }
148            }
149        }
150    }
151    uncommitted_diff
152}
153
154fn generate_timestamp_name() -> String {
155    let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
156    match format {
157        Ok(format) => {
158            let now = time::OffsetDateTime::now_local()
159                .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
160            now.format(&format)
161                .unwrap_or_else(|_| "unknown-time".to_string())
162        }
163        Err(_) => "unknown-time".to_string(),
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`, covering two kinds of change:
    /// - changes already on disk relative to HEAD (comments 1 and 2), which should appear in
    ///   `uncommitted_diff`;
    /// - edits made through the buffer after it is registered with the `EditPredictionStore`
    ///   (comments 3 and 4), which should appear in `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents recorded as HEAD for `src/main.rs`.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Working-copy contents: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the store before editing so the edits below are recorded.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        buffer.update(cx, |buffer, cx| {
            // Insert at row 6 before row 4 so the second insertion point is not shifted by the
            // first.
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Cursor at the very start of the buffer, so the marker should lead the excerpt.
        let mut example = cx
            .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
            .await
            .unwrap();
        // The generated name is time-based; pin it so the comparison is deterministic.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                // Only the on-disk changes (comments 1 and 2) relative to HEAD.
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: indoc! {"
                    <|user_cursor|>fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the buffer edits made after registration (comments 3 and 4).
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patch: "".to_string(),
            }
        );
    }

    /// Installs test settings, logging, a client backed by a fake HTTP client that answers
    /// 404, the language-model globals, and a global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}