1use crate::{
2 EditPredictionStore, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use gpui::{App, Entity, Task};
9use language::{Buffer, ToPoint as _};
10use project::{Project, WorktreeId};
11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
12use text::BufferSnapshot as TextBufferSnapshot;
13
14pub fn capture_example(
15 project: Entity<Project>,
16 buffer: Entity<Buffer>,
17 cursor_anchor: language::Anchor,
18 cx: &mut App,
19) -> Option<Task<Result<ExampleSpec>>> {
20 let ep_store = EditPredictionStore::try_global(cx)?;
21 let snapshot = buffer.read(cx).snapshot();
22 let file = snapshot.file()?;
23 let worktree_id = file.worktree_id(cx);
24 let repository = project.read(cx).active_repository(cx)?;
25 let repository_snapshot = repository.read(cx).snapshot();
26 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
27 let cursor_path = worktree.read(cx).root_name().join(file.path());
28 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
29 return None;
30 }
31
32 let repository_url = repository_snapshot
33 .remote_origin_url
34 .clone()
35 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
36 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
37
38 let mut events = ep_store.update(cx, |store, cx| {
39 store.edit_history_for_project_with_pause_split_last_event(&project, cx)
40 });
41
42 let git_store = project.read(cx).git_store().clone();
43
44 Some(cx.spawn(async move |mut cx| {
45 let snapshots_by_path =
46 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
47
48 events.retain(|stored_event| {
49 match stored_event.event.as_ref() {
50 zeta_prompt::Event::BufferChange { path, .. } => {
51 if !snapshots_by_path.contains_key(path) {
52 return false;
53 }
54 }
55 }
56 true
57 });
58
59 let line_comment_prefix = snapshot
60 .language()
61 .and_then(|lang| lang.config().line_comments.first())
62 .map(|s| s.to_string())
63 .unwrap_or_default();
64 let (cursor_excerpt, cursor_offset) = cx
65 .background_executor()
66 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
67 .await;
68 let uncommitted_diff = cx
69 .background_executor()
70 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
71 .await;
72
73 let mut edit_history = String::new();
74 for stored_event in &events {
75 zeta_prompt::write_event(&mut edit_history, &stored_event.event);
76 if !edit_history.ends_with('\n') {
77 edit_history.push('\n');
78 }
79 }
80
81 let mut spec = ExampleSpec {
82 name: generate_timestamp_name(),
83 repository_url,
84 revision,
85 uncommitted_diff,
86 cursor_path: cursor_path.as_std_path().into(),
87 cursor_position: String::new(),
88 edit_history,
89 expected_patches: Vec::new(),
90 };
91 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
92 Ok(spec)
93 }))
94}
95
96fn compute_cursor_excerpt(
97 snapshot: &language::BufferSnapshot,
98 cursor_anchor: language::Anchor,
99) -> (String, usize) {
100 use text::ToOffset as _;
101
102 let cursor_point = cursor_anchor.to_point(snapshot);
103 let (_editable_range, context_range) =
104 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
105 let context_start_offset = context_range.start.to_offset(snapshot);
106 let cursor_offset = cursor_anchor.to_offset(snapshot);
107 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
108 let excerpt = snapshot.text_for_range(context_range).collect::<String>();
109 (excerpt, cursor_offset_in_excerpt)
110}
111
112async fn collect_snapshots(
113 project: &Entity<Project>,
114 git_store: &Entity<project::git_store::GitStore>,
115 worktree_id: WorktreeId,
116 events: &[StoredEvent],
117 cx: &mut gpui::AsyncApp,
118) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
119 let mut snapshots_by_path = HashMap::default();
120 let root_name = project.read_with(cx, |project, cx| {
121 project
122 .worktree_for_id(worktree_id, cx)
123 .unwrap()
124 .read(cx)
125 .root_name()
126 .to_owned()
127 })?;
128 for stored_event in events {
129 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
130 if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
131 let project_path = project
132 .find_project_path(path, cx)
133 .filter(|path| path.worktree_id == worktree_id)?;
134 let full_path = root_name.join(&project_path.path).as_std_path().into();
135 Some((project_path, full_path))
136 })? {
137 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
138 let buffer = project
139 .update(cx, |project, cx| {
140 project.open_buffer(project_path.clone(), cx)
141 })?
142 .await?;
143 let diff = git_store
144 .update(cx, |git_store, cx| {
145 git_store.open_uncommitted_diff(buffer.clone(), cx)
146 })?
147 .await?;
148 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
149 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
150 }
151 }
152 }
153 Ok(snapshots_by_path)
154}
155
156fn compute_uncommitted_diff(
157 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
158) -> String {
159 let mut uncommitted_diff = String::new();
160 for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
161 if let Some(head_text) = &diff_snapshot.base_text_string() {
162 let file_diff = language::unified_diff(head_text, &before_text.text());
163 if !file_diff.is_empty() {
164 let path_str = full_path.to_string_lossy();
165 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
166 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
167 uncommitted_diff.push_str(&file_diff);
168 if !uncommitted_diff.ends_with('\n') {
169 uncommitted_diff.push('\n');
170 }
171 }
172 }
173 }
174 uncommitted_diff
175}
176
177fn generate_timestamp_name() -> String {
178 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
179 match format {
180 Ok(format) => {
181 let now = time::OffsetDateTime::now_local()
182 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
183 now.format(&format)
184 .unwrap_or_else(|_| "unknown-time".to_string())
185 }
186 Err(_) => "unknown-time".to_string(),
187 }
188}
189
// Tests for `capture_example`, run against a fake filesystem with a fake git
// HEAD and remote.
// NOTE(review): the string fixtures below are whitespace-sensitive; they are
// left exactly as found — confirm their indentation against the upstream file.
190#[cfg(test)]
191mod tests {
192 use super::*;
193 use client::{Client, UserStore};
194 use clock::FakeSystemClock;
195 use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
196 use indoc::indoc;
197 use language::{Anchor, Point};
198 use project::{FakeFs, Project};
199 use serde_json::json;
200 use settings::SettingsStore;
201 use std::path::Path;
202
    // Captures an example after two in-buffer edits and compares every field of
    // the resulting spec; the generated name is overwritten with a fixed value
    // before the comparison.
203 #[gpui::test]
204 async fn test_capture_example(cx: &mut TestAppContext) {
205 init_test(cx);
206 let fs = FakeFs::new(cx.executor());
207
        // Contents registered below as the HEAD version of `src/main.rs`.
208 let committed_contents = indoc! {"
209 fn main() {
210 one();
211 two();
212 three();
213 four();
214 five();
215 six();
216 seven();
217 eight();
218 nine();
219 }
220 "};
221
        // Contents placed on the fake disk: the HEAD version plus the
        // "comment 1" and "comment 2" lines.
222 let disk_contents = indoc! {"
223 fn main() {
224 // comment 1
225 one();
226 two();
227 three();
228 four();
229 five();
230 six();
231 seven();
232 eight();
233 // comment 2
234 nine();
235 }
236 "};
237
238 fs.insert_tree(
239 "/project",
240 json!({
241 ".git": {},
242 "src": {
243 "main.rs": disk_contents,
244 }
245 }),
246 )
247 .await;
248
        // Give the fake repository a HEAD commit and an `origin` remote; the
        // same sha and URL are asserted as `revision` and `repository_url` below.
249 fs.set_head_for_repo(
250 Path::new("/project/.git"),
251 &[("src/main.rs", committed_contents.to_string())],
252 "abc123def456",
253 );
254 fs.set_remote_for_repo(
255 Path::new("/project/.git"),
256 "origin",
257 "https://github.com/test/repo.git",
258 );
259
260 let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
261
262 let buffer = project
263 .update(cx, |project, cx| {
264 project.open_local_buffer("/project/src/main.rs", cx)
265 })
266 .await
267 .unwrap();
268
        // Register the buffer with the edit prediction store before editing it.
269 let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
270 ep_store.update(cx, |ep_store, cx| {
271 ep_store.register_buffer(&buffer, &project, cx)
272 });
273 cx.run_until_parked();
274
        // Insert "comment 3" and then "comment 4"; these two edits are what the
        // expected `edit_history` below contains.
275 buffer.update(cx, |buffer, cx| {
276 let point = Point::new(6, 0);
277 buffer.edit([(point..point, " // comment 3\n")], None, cx);
278 let point = Point::new(4, 0);
279 buffer.edit([(point..point, " // comment 4\n")], None, cx);
280
281 pretty_assertions::assert_eq!(
282 buffer.text(),
283 indoc! {"
284 fn main() {
285 // comment 1
286 one();
287 two();
288 // comment 4
289 three();
290 four();
291 // comment 3
292 five();
293 six();
294 seven();
295 eight();
296 // comment 2
297 nine();
298 }
299 "}
300 );
301 });
302 cx.run_until_parked();
303
        // Capture with the cursor at `Anchor::MIN`.
304 let mut example = cx
305 .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
306 .await
307 .unwrap();
        // The real name is time-based, so pin it for the comparison.
308 example.name = "test".to_string();
309
        // Expected spec: `uncommitted_diff` holds comments 1 and 2, while
        // `edit_history` holds comments 3 and 4.
310 pretty_assertions::assert_eq!(
311 example,
312 ExampleSpec {
313 name: "test".to_string(),
314 repository_url: "https://github.com/test/repo.git".to_string(),
315 revision: "abc123def456".to_string(),
316 uncommitted_diff: indoc! {"
317 --- a/project/src/main.rs
318 +++ b/project/src/main.rs
319 @@ -1,4 +1,5 @@
320 fn main() {
321 + // comment 1
322 one();
323 two();
324 three();
325 @@ -7,5 +8,6 @@
326 six();
327 seven();
328 eight();
329 + // comment 2
330 nine();
331 }
332 "}
333 .to_string(),
334 cursor_path: Path::new("project/src/main.rs").into(),
335 cursor_position: indoc! {"
336 fn main() {
337 ^[CURSOR_POSITION]
338 // comment 1
339 one();
340 two();
341 // comment 4
342 three();
343 four();
344 // comment 3
345 five();
346 six();
347 seven();
348 eight();
349 // comment 2
350 nine();
351 }
352 "}
353 .to_string(),
354 edit_history: indoc! {"
355 --- a/project/src/main.rs
356 +++ b/project/src/main.rs
357 @@ -2,8 +2,10 @@
358 // comment 1
359 one();
360 two();
361 + // comment 4
362 three();
363 four();
364 + // comment 3
365 five();
366 six();
367 seven();
368 "}
369 .to_string(),
370 expected_patches: Vec::new()
371 }
372 );
373 }
374
    // Installs the globals the test relies on: a test settings store, test
    // logging, a client built on a fake HTTP client, the language model
    // subsystem, and the global `EditPredictionStore`.
375 fn init_test(cx: &mut TestAppContext) {
376 cx.update(|cx| {
377 let settings_store = SettingsStore::test(cx);
378 cx.set_global(settings_store);
379 zlog::init_test();
380 let http_client = FakeHttpClient::with_404_response();
381 let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
382 language_model::init(client.clone(), cx);
383 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
384 EditPredictionStore::global(&client, &user_store, cx);
385 })
386 }
387}