1use crate::{
2 EditPredictionStore, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use gpui::{App, Entity, Task};
9use language::{Buffer, ToPoint as _};
10use project::Project;
11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
12use text::BufferSnapshot as TextBufferSnapshot;
13
14pub fn capture_example(
15 project: Entity<Project>,
16 buffer: Entity<Buffer>,
17 cursor_anchor: language::Anchor,
18 cx: &mut App,
19) -> Option<Task<Result<ExampleSpec>>> {
20 let ep_store = EditPredictionStore::try_global(cx)?;
21 let snapshot = buffer.read(cx).snapshot();
22 let file = snapshot.file()?;
23 let worktree_id = file.worktree_id(cx);
24 let repository = project.read(cx).active_repository(cx)?;
25 let repository_snapshot = repository.read(cx).snapshot();
26 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
27 let cursor_path = worktree.read(cx).root_name().join(file.path());
28 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
29 return None;
30 }
31
32 let repository_url = repository_snapshot
33 .remote_origin_url
34 .clone()
35 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
36 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
37
38 let events = ep_store.update(cx, |store, cx| {
39 store.edit_history_for_project_with_pause_split_last_event(&project, cx)
40 });
41
42 let git_store = project.read(cx).git_store().clone();
43
44 Some(cx.spawn(async move |mut cx| {
45 let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
46
47 let line_comment_prefix = snapshot
48 .language()
49 .and_then(|lang| lang.config().line_comments.first())
50 .map(|s| s.to_string())
51 .unwrap_or_default();
52 let (cursor_excerpt, cursor_offset) = cx
53 .background_executor()
54 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
55 .await;
56 let uncommitted_diff = cx
57 .background_executor()
58 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
59 .await;
60
61 let mut edit_history = String::new();
62 for stored_event in &events {
63 zeta_prompt::write_event(&mut edit_history, &stored_event.event);
64 if !edit_history.ends_with('\n') {
65 edit_history.push('\n');
66 }
67 }
68
69 let mut spec = ExampleSpec {
70 name: generate_timestamp_name(),
71 repository_url,
72 revision,
73 uncommitted_diff,
74 cursor_path: cursor_path.as_std_path().into(),
75 cursor_position: String::new(),
76 edit_history,
77 expected_patch: String::new(),
78 };
79 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
80 Ok(spec)
81 }))
82}
83
84fn compute_cursor_excerpt(
85 snapshot: &language::BufferSnapshot,
86 cursor_anchor: language::Anchor,
87) -> (String, usize) {
88 use text::ToOffset as _;
89
90 let cursor_point = cursor_anchor.to_point(snapshot);
91 let (_editable_range, context_range) =
92 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
93 let context_start_offset = context_range.start.to_offset(snapshot);
94 let cursor_offset = cursor_anchor.to_offset(snapshot);
95 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
96 let excerpt = snapshot.text_for_range(context_range).collect::<String>();
97 (excerpt, cursor_offset_in_excerpt)
98}
99
100async fn collect_snapshots(
101 project: &Entity<Project>,
102 git_store: &Entity<project::git_store::GitStore>,
103 events: &[StoredEvent],
104 cx: &mut gpui::AsyncApp,
105) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
106 let mut snapshots_by_path = HashMap::default();
107 for stored_event in events {
108 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
109 if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
110 let project_path = project.find_project_path(path, cx)?;
111 let full_path = project
112 .worktree_for_id(project_path.worktree_id, cx)?
113 .read(cx)
114 .root_name()
115 .join(&project_path.path)
116 .as_std_path()
117 .into();
118 Some((project_path, full_path))
119 })? {
120 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
121 let buffer = project
122 .update(cx, |project, cx| {
123 project.open_buffer(project_path.clone(), cx)
124 })?
125 .await?;
126 let diff = git_store
127 .update(cx, |git_store, cx| {
128 git_store.open_uncommitted_diff(buffer.clone(), cx)
129 })?
130 .await?;
131 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
132 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
133 }
134 }
135 }
136 Ok(snapshots_by_path)
137}
138
139fn compute_uncommitted_diff(
140 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
141) -> String {
142 let mut uncommitted_diff = String::new();
143 for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
144 if let Some(head_text) = &diff_snapshot.base_text_string() {
145 let file_diff = language::unified_diff(head_text, &before_text.text());
146 if !file_diff.is_empty() {
147 let path_str = full_path.to_string_lossy();
148 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
149 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
150 uncommitted_diff.push_str(&file_diff);
151 if !uncommitted_diff.ends_with('\n') {
152 uncommitted_diff.push('\n');
153 }
154 }
155 }
156 }
157 uncommitted_diff
158}
159
160fn generate_timestamp_name() -> String {
161 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
162 match format {
163 Ok(format) => {
164 let now = time::OffsetDateTime::now_local()
165 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
166 now.format(&format)
167 .unwrap_or_else(|_| "unknown-time".to_string())
168 }
169 Err(_) => "unknown-time".to_string(),
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` on a fake single-file repository.
    ///
    /// Changes that exist on disk before the buffer is registered are expected in
    /// `uncommitted_diff`. Edits made through the buffer afterwards are expected in
    /// `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents recorded as the HEAD commit of the fake repository.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the edit prediction store; the expected
        // `edit_history` below relies on subsequent edits being recorded.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Two in-buffer edits: insert "comment 3" before `five();`, then
        // "comment 4" before `three();`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        let mut example = cx
            .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
            .await
            .unwrap();
        // The generated name is time-dependent; pin it so the comparison is stable.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                // HEAD vs. the on-disk contents (comments 1 and 2 only).
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                // Cursor is at `Anchor::MIN`; the buffer has no language, so the
                // marker line carries no comment prefix.
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // The in-buffer edits (comments 3 and 4) show up here instead.
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patch: "".to_string(),
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store, test
    /// logging, a client backed by a fake HTTP client that answers 404, the
    /// language model subsystem, and the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}