1use crate::{
2 EditPredictionStore, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use gpui::{App, Entity, Task};
9use language::{Buffer, ToPoint as _};
10use project::Project;
11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
12use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
13
14pub fn capture_example(
15 project: Entity<Project>,
16 buffer: Entity<Buffer>,
17 cursor_anchor: language::Anchor,
18 cx: &mut App,
19) -> Option<Task<Result<ExampleSpec>>> {
20 let ep_store = EditPredictionStore::try_global(cx)?;
21 let snapshot = buffer.read(cx).snapshot();
22 let file = snapshot.file()?;
23 let worktree_id = file.worktree_id(cx);
24 let repository = project.read(cx).active_repository(cx)?;
25 let repository_snapshot = repository.read(cx).snapshot();
26 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
27 let cursor_path = worktree.read(cx).root_name().join(file.path());
28 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
29 return None;
30 }
31
32 let repository_url = repository_snapshot
33 .remote_origin_url
34 .clone()
35 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
36 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
37
38 let events = ep_store.update(cx, |store, cx| {
39 store.edit_history_for_project_with_pause_split_last_event(&project, cx)
40 });
41
42 let git_store = project.read(cx).git_store().clone();
43
44 Some(cx.spawn(async move |mut cx| {
45 let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
46 let cursor_excerpt = cx
47 .background_executor()
48 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
49 .await;
50 let uncommitted_diff = cx
51 .background_executor()
52 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
53 .await;
54
55 let mut edit_history = String::new();
56 for stored_event in &events {
57 zeta_prompt::write_event(&mut edit_history, &stored_event.event);
58 if !edit_history.ends_with('\n') {
59 edit_history.push('\n');
60 }
61 }
62
63 Ok(ExampleSpec {
64 name: generate_timestamp_name(),
65 repository_url,
66 revision,
67 uncommitted_diff,
68 cursor_path: cursor_path.as_std_path().into(),
69 cursor_position: cursor_excerpt,
70 edit_history,
71 expected_patch: String::new(),
72 })
73 }))
74}
75
76fn compute_cursor_excerpt(
77 snapshot: &language::BufferSnapshot,
78 cursor_anchor: language::Anchor,
79) -> String {
80 let cursor_point = cursor_anchor.to_point(snapshot);
81 let (_editable_range, context_range) =
82 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
83
84 let context_start_offset = context_range.start.to_offset(snapshot);
85 let cursor_offset = cursor_anchor.to_offset(snapshot);
86 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
87 let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
88 if cursor_offset_in_excerpt <= excerpt.len() {
89 excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
90 }
91 excerpt
92}
93
94async fn collect_snapshots(
95 project: &Entity<Project>,
96 git_store: &Entity<project::git_store::GitStore>,
97 events: &[StoredEvent],
98 cx: &mut gpui::AsyncApp,
99) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
100 let mut snapshots_by_path = HashMap::default();
101 for stored_event in events {
102 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
103 if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
104 let project_path = project.find_project_path(path, cx)?;
105 let full_path = project
106 .worktree_for_id(project_path.worktree_id, cx)?
107 .read(cx)
108 .root_name()
109 .join(&project_path.path)
110 .as_std_path()
111 .into();
112 Some((project_path, full_path))
113 })? {
114 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
115 let buffer = project
116 .update(cx, |project, cx| {
117 project.open_buffer(project_path.clone(), cx)
118 })?
119 .await?;
120 let diff = git_store
121 .update(cx, |git_store, cx| {
122 git_store.open_uncommitted_diff(buffer.clone(), cx)
123 })?
124 .await?;
125 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
126 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
127 }
128 }
129 }
130 Ok(snapshots_by_path)
131}
132
133fn compute_uncommitted_diff(
134 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
135) -> String {
136 let mut uncommitted_diff = String::new();
137 for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
138 if let Some(head_text) = &diff_snapshot.base_text_string() {
139 let file_diff = language::unified_diff(head_text, &before_text.text());
140 if !file_diff.is_empty() {
141 let path_str = full_path.to_string_lossy();
142 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
143 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
144 uncommitted_diff.push_str(&file_diff);
145 if !uncommitted_diff.ends_with('\n') {
146 uncommitted_diff.push('\n');
147 }
148 }
149 }
150 }
151 uncommitted_diff
152}
153
154fn generate_timestamp_name() -> String {
155 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
156 match format {
157 Ok(format) => {
158 let now = time::OffsetDateTime::now_local()
159 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
160 now.format(&format)
161 .unwrap_or_else(|_| "unknown-time".to_string())
162 }
163 Err(_) => "unknown-time".to_string(),
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end test of [`capture_example`] against a fake git repository.
    ///
    /// - HEAD contains `committed_contents`.
    /// - The file on disk (`disk_contents`) additionally has "comment 1" and "comment 2".
    /// - The open buffer is then edited to add "comment 3" and "comment 4".
    ///
    /// The expected spec has:
    /// - the HEAD-to-disk changes in `uncommitted_diff`,
    /// - the buffer edits in `edit_history`,
    /// - the cursor marker at the start of the buffer (`Anchor::MIN`).
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at the HEAD commit.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // Contents on disk: HEAD plus two uncommitted lines ("comment 1" and "comment 2").
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // HEAD commit `abc123def456` holds the committed contents; `origin` supplies the URL.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the edit prediction store before editing it.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // In-buffer edits: "comment 3" goes before `five();`, then "comment 4" before `three();`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // The generated name is a timestamp; overwrite it so the whole spec can be compared.
        let mut example = cx
            .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
            .await
            .unwrap();
        example.name = "test".to_string();

        // `uncommitted_diff` contains only comments 1 and 2 (HEAD vs. pre-edit text), while
        // `edit_history` contains the in-buffer edits (comments 4 and 3).
        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                cursor_position: indoc! {"
                    <|user_cursor|>fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patch: "".to_string(),
            }
        );
    }

    /// Sets up the test globals used by these tests.
    ///
    /// - Installs a test `SettingsStore` and initializes test logging.
    /// - Initializes `language_model` with a client backed by
    ///   `FakeHttpClient::with_404_response()`.
    /// - Creates the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}