1use crate::{
2 EditPredictionStore, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use gpui::{App, Entity, Task};
9use language::{Buffer, ToPoint as _};
10use project::Project;
11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
12use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
13
14pub fn capture_example(
15 project: Entity<Project>,
16 buffer: Entity<Buffer>,
17 cursor_anchor: language::Anchor,
18 last_event_is_expected_patch: bool,
19 cx: &mut App,
20) -> Option<Task<Result<ExampleSpec>>> {
21 let ep_store = EditPredictionStore::try_global(cx)?;
22 let snapshot = buffer.read(cx).snapshot();
23 let file = snapshot.file()?;
24 let worktree_id = file.worktree_id(cx);
25 let repository = project.read(cx).active_repository(cx)?;
26 let repository_snapshot = repository.read(cx).snapshot();
27 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
28 let cursor_path = worktree.read(cx).root_name().join(file.path());
29 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
30 return None;
31 }
32
33 let repository_url = repository_snapshot
34 .remote_origin_url
35 .clone()
36 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
37 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
38
39 let mut events = ep_store.update(cx, |store, cx| {
40 store.edit_history_for_project_with_pause_split_last_event(&project, cx)
41 });
42
43 let git_store = project.read(cx).git_store().clone();
44
45 Some(cx.spawn(async move |mut cx| {
46 let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
47 let cursor_excerpt = cx
48 .background_executor()
49 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
50 .await;
51 let uncommitted_diff = cx
52 .background_executor()
53 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
54 .await;
55
56 let mut edit_history = String::new();
57 let mut expected_patch = String::new();
58 if last_event_is_expected_patch {
59 if let Some(stored_event) = events.pop() {
60 zeta_prompt::write_event(&mut expected_patch, &stored_event.event);
61 }
62 }
63
64 for stored_event in &events {
65 zeta_prompt::write_event(&mut edit_history, &stored_event.event);
66 if !edit_history.ends_with('\n') {
67 edit_history.push('\n');
68 }
69 }
70
71 let name = generate_timestamp_name();
72
73 Ok(ExampleSpec {
74 name,
75 repository_url,
76 revision,
77 uncommitted_diff,
78 cursor_path: cursor_path.as_std_path().into(),
79 cursor_position: cursor_excerpt,
80 edit_history,
81 expected_patch,
82 })
83 }))
84}
85
86fn compute_cursor_excerpt(
87 snapshot: &language::BufferSnapshot,
88 cursor_anchor: language::Anchor,
89) -> String {
90 let cursor_point = cursor_anchor.to_point(snapshot);
91 let (_editable_range, context_range) =
92 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
93
94 let context_start_offset = context_range.start.to_offset(snapshot);
95 let cursor_offset = cursor_anchor.to_offset(snapshot);
96 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
97 let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
98 if cursor_offset_in_excerpt <= excerpt.len() {
99 excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
100 }
101 excerpt
102}
103
104async fn collect_snapshots(
105 project: &Entity<Project>,
106 git_store: &Entity<project::git_store::GitStore>,
107 events: &[StoredEvent],
108 cx: &mut gpui::AsyncApp,
109) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
110 let mut snapshots_by_path = HashMap::default();
111 for stored_event in events {
112 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
113 if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
114 let project_path = project.find_project_path(path, cx)?;
115 let full_path = project
116 .worktree_for_id(project_path.worktree_id, cx)?
117 .read(cx)
118 .root_name()
119 .join(&project_path.path)
120 .as_std_path()
121 .into();
122 Some((project_path, full_path))
123 })? {
124 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
125 let buffer = project
126 .update(cx, |project, cx| {
127 project.open_buffer(project_path.clone(), cx)
128 })?
129 .await?;
130 let diff = git_store
131 .update(cx, |git_store, cx| {
132 git_store.open_uncommitted_diff(buffer.clone(), cx)
133 })?
134 .await?;
135 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
136 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
137 }
138 }
139 }
140 Ok(snapshots_by_path)
141}
142
143fn compute_uncommitted_diff(
144 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
145) -> String {
146 let mut uncommitted_diff = String::new();
147 for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
148 if let Some(head_text) = &diff_snapshot.base_text_string() {
149 let file_diff = language::unified_diff(head_text, &before_text.text());
150 if !file_diff.is_empty() {
151 let path_str = full_path.to_string_lossy();
152 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
153 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
154 uncommitted_diff.push_str(&file_diff);
155 if !uncommitted_diff.ends_with('\n') {
156 uncommitted_diff.push('\n');
157 }
158 }
159 }
160 }
161 uncommitted_diff
162}
163
164fn generate_timestamp_name() -> String {
165 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
166 match format {
167 Ok(format) => {
168 let now = time::OffsetDateTime::now_local()
169 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
170 now.format(&format)
171 .unwrap_or_else(|_| "unknown-time".to_string())
172 }
173 Err(_) => "unknown-time".to_string(),
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example` against a fake git repository.
    ///
    /// HEAD and the on-disk file differ by two comment lines, which are expected in
    /// `uncommitted_diff`. Two more lines are then inserted through the buffer after it
    /// is registered with the store, and those are expected in `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Text recorded as the HEAD version of src/main.rs.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Working-tree text: HEAD plus "// comment 1" and "// comment 2".
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the store; the edits made below are expected to
        // show up in the captured `edit_history`.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Insert "// comment 3" before `five();` (row 6), then "// comment 4" before
        // `three();` (row 4); the second edit is above the first, so row 6 is unaffected.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Cursor at the very start of the buffer. The generated name is overwritten so
        // the whole spec can be compared with a fixed value.
        let mut example = cx
            .update(|cx| {
                capture_example(project.clone(), buffer.clone(), Anchor::MIN, false, cx).unwrap()
            })
            .await
            .unwrap();
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                // HEAD -> on-disk text: only comments 1 and 2, not the buffer edits.
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: indoc! {"
                    <|user_cursor|>fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // The two buffer edits, rendered as a single hunk.
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                // `last_event_is_expected_patch` was false, so nothing is split off.
                expected_patch: "".to_string(),
            }
        );
    }

    /// Installs the globals this module's test relies on:
    /// - a test `SettingsStore`,
    /// - test logging,
    /// - a `Client` backed by a fake HTTP client that answers 404,
    /// - `language_model`,
    /// - the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}