1use crate::{
2 EditPredictionStore, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use gpui::{App, Entity, Task};
9use language::{Buffer, ToPoint as _};
10use project::{Project, WorktreeId};
11use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
12use text::BufferSnapshot as TextBufferSnapshot;
13
14pub fn capture_example(
15 project: Entity<Project>,
16 buffer: Entity<Buffer>,
17 cursor_anchor: language::Anchor,
18 cx: &mut App,
19) -> Option<Task<Result<ExampleSpec>>> {
20 let ep_store = EditPredictionStore::try_global(cx)?;
21 let snapshot = buffer.read(cx).snapshot();
22 let file = snapshot.file()?;
23 let worktree_id = file.worktree_id(cx);
24 let repository = project.read(cx).active_repository(cx)?;
25 let repository_snapshot = repository.read(cx).snapshot();
26 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
27 let cursor_path = worktree.read(cx).root_name().join(file.path());
28 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
29 return None;
30 }
31
32 let repository_url = repository_snapshot
33 .remote_origin_url
34 .clone()
35 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
36 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
37
38 let mut events = ep_store.update(cx, |store, cx| {
39 store.edit_history_for_project_with_pause_split_last_event(&project, cx)
40 });
41
42 let git_store = project.read(cx).git_store().clone();
43
44 Some(cx.spawn(async move |mut cx| {
45 let snapshots_by_path =
46 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
47
48 events.retain(|stored_event| {
49 match stored_event.event.as_ref() {
50 zeta_prompt::Event::BufferChange { path, .. } => {
51 if !snapshots_by_path.contains_key(path) {
52 return false;
53 }
54 }
55 }
56 true
57 });
58
59 let line_comment_prefix = snapshot
60 .language()
61 .and_then(|lang| lang.config().line_comments.first())
62 .map(|s| s.to_string())
63 .unwrap_or_default();
64 let (cursor_excerpt, cursor_offset) = cx
65 .background_executor()
66 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
67 .await;
68 let uncommitted_diff = cx
69 .background_executor()
70 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
71 .await;
72
73 let mut edit_history = String::new();
74 for stored_event in &events {
75 zeta_prompt::write_event(&mut edit_history, &stored_event.event);
76 if !edit_history.ends_with('\n') {
77 edit_history.push('\n');
78 }
79 }
80
81 let mut spec = ExampleSpec {
82 name: generate_timestamp_name(),
83 repository_url,
84 revision,
85 tags: Vec::new(),
86 reasoning: None,
87 uncommitted_diff,
88 cursor_path: cursor_path.as_std_path().into(),
89 cursor_position: String::new(),
90 edit_history,
91 expected_patches: Vec::new(),
92 };
93 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
94 Ok(spec)
95 }))
96}
97
98fn compute_cursor_excerpt(
99 snapshot: &language::BufferSnapshot,
100 cursor_anchor: language::Anchor,
101) -> (String, usize) {
102 use text::ToOffset as _;
103
104 let cursor_point = cursor_anchor.to_point(snapshot);
105 let (_editable_range, context_range) =
106 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
107 let context_start_offset = context_range.start.to_offset(snapshot);
108 let cursor_offset = cursor_anchor.to_offset(snapshot);
109 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
110 let excerpt = snapshot.text_for_range(context_range).collect::<String>();
111 (excerpt, cursor_offset_in_excerpt)
112}
113
114async fn collect_snapshots(
115 project: &Entity<Project>,
116 git_store: &Entity<project::git_store::GitStore>,
117 worktree_id: WorktreeId,
118 events: &[StoredEvent],
119 cx: &mut gpui::AsyncApp,
120) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
121 let mut snapshots_by_path = HashMap::default();
122 let root_name = project.read_with(cx, |project, cx| {
123 project
124 .worktree_for_id(worktree_id, cx)
125 .unwrap()
126 .read(cx)
127 .root_name()
128 .to_owned()
129 })?;
130 for stored_event in events {
131 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
132 if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
133 let project_path = project
134 .find_project_path(path, cx)
135 .filter(|path| path.worktree_id == worktree_id)?;
136 let full_path = root_name.join(&project_path.path).as_std_path().into();
137 Some((project_path, full_path))
138 })? {
139 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
140 let buffer = project
141 .update(cx, |project, cx| {
142 project.open_buffer(project_path.clone(), cx)
143 })?
144 .await?;
145 let diff = git_store
146 .update(cx, |git_store, cx| {
147 git_store.open_uncommitted_diff(buffer.clone(), cx)
148 })?
149 .await?;
150 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
151 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
152 }
153 }
154 }
155 Ok(snapshots_by_path)
156}
157
158fn compute_uncommitted_diff(
159 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
160) -> String {
161 let mut uncommitted_diff = String::new();
162 for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
163 if let Some(head_text) = &diff_snapshot.base_text_string() {
164 let file_diff = language::unified_diff(head_text, &before_text.text());
165 if !file_diff.is_empty() {
166 let path_str = full_path.to_string_lossy();
167 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
168 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
169 uncommitted_diff.push_str(&file_diff);
170 if !uncommitted_diff.ends_with('\n') {
171 uncommitted_diff.push('\n');
172 }
173 }
174 }
175 }
176 uncommitted_diff
177}
178
179fn generate_timestamp_name() -> String {
180 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
181 match format {
182 Ok(format) => {
183 let now = time::OffsetDateTime::now_local()
184 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
185 now.format(&format)
186 .unwrap_or_else(|_| "unknown-time".to_string())
187 }
188 Err(_) => "unknown-time".to_string(),
189 }
190}
191
// Tests for `capture_example`, run against a fake filesystem with fake git
// repository state.
192#[cfg(test)]
193mod tests {
194 use super::*;
195 use client::{Client, UserStore};
196 use clock::FakeSystemClock;
197 use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
198 use indoc::indoc;
199 use language::{Anchor, Point};
200 use project::{FakeFs, Project};
201 use serde_json::json;
202 use settings::SettingsStore;
203 use std::path::Path;
204
// Captures an example from a buffer that has both uncommitted on-disk changes
// and additional in-memory edits, then checks every field of the resulting
// spec.
205 #[gpui::test]
206 async fn test_capture_example(cx: &mut TestAppContext) {
207 init_test(cx);
208 let fs = FakeFs::new(cx.executor());
209
// Contents registered below as the repository's HEAD version of the file.
210 let committed_contents = indoc! {"
211 fn main() {
212 one();
213 two();
214 three();
215 four();
216 five();
217 six();
218 seven();
219 eight();
220 nine();
221 }
222 "};
223
// Contents on disk: the committed contents plus the "comment 1" and
// "comment 2" lines.
224 let disk_contents = indoc! {"
225 fn main() {
226 // comment 1
227 one();
228 two();
229 three();
230 four();
231 five();
232 six();
233 seven();
234 eight();
235 // comment 2
236 nine();
237 }
238 "};
239
240 fs.insert_tree(
241 "/project",
242 json!({
243 ".git": {},
244 "src": {
245 "main.rs": disk_contents,
246 }
247 }),
248 )
249 .await;
250
// Give the fake repository a HEAD (file contents plus a commit SHA) and an
// "origin" remote; `capture_example` returns `None` without a head commit
// and a remote URL.
251 fs.set_head_for_repo(
252 Path::new("/project/.git"),
253 &[("src/main.rs", committed_contents.to_string())],
254 "abc123def456",
255 );
256 fs.set_remote_for_repo(
257 Path::new("/project/.git"),
258 "origin",
259 "https://github.com/test/repo.git",
260 );
261
262 let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
263
264 let buffer = project
265 .update(cx, |project, cx| {
266 project.open_local_buffer("/project/src/main.rs", cx)
267 })
268 .await
269 .unwrap();
270
// Register the buffer with the edit prediction store before editing it.
271 let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
272 ep_store.update(cx, |ep_store, cx| {
273 ep_store.register_buffer(&buffer, &project, cx)
274 });
275 cx.run_until_parked();
276
// Make two in-memory insertions: "comment 3" at row 6, then "comment 4" at
// row 4, and confirm the resulting buffer text.
277 buffer.update(cx, |buffer, cx| {
278 let point = Point::new(6, 0);
279 buffer.edit([(point..point, " // comment 3\n")], None, cx);
280 let point = Point::new(4, 0);
281 buffer.edit([(point..point, " // comment 4\n")], None, cx);
282
283 pretty_assertions::assert_eq!(
284 buffer.text(),
285 indoc! {"
286 fn main() {
287 // comment 1
288 one();
289 two();
290 // comment 4
291 three();
292 four();
293 // comment 3
294 five();
295 six();
296 seven();
297 eight();
298 // comment 2
299 nine();
300 }
301 "}
302 );
303 });
304 cx.run_until_parked();
305
// Capture with the cursor at `Anchor::MIN`. The name is generated from the
// current time, so it is overwritten with a fixed value before comparing.
306 let mut example = cx
307 .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
308 .await
309 .unwrap();
310 example.name = "test".to_string();
311
// Expected spec: `uncommitted_diff` holds only the HEAD-to-disk changes
// (comments 1 and 2), while `edit_history` holds the in-memory edits
// (comments 3 and 4).
312 pretty_assertions::assert_eq!(
313 example,
314 ExampleSpec {
315 name: "test".to_string(),
316 repository_url: "https://github.com/test/repo.git".to_string(),
317 revision: "abc123def456".to_string(),
318 tags: Vec::new(),
319 reasoning: None,
320 uncommitted_diff: indoc! {"
321 --- a/project/src/main.rs
322 +++ b/project/src/main.rs
323 @@ -1,4 +1,5 @@
324 fn main() {
325 + // comment 1
326 one();
327 two();
328 three();
329 @@ -7,5 +8,6 @@
330 six();
331 seven();
332 eight();
333 + // comment 2
334 nine();
335 }
336 "}
337 .to_string(),
338 cursor_path: Path::new("project/src/main.rs").into(),
339 cursor_position: indoc! {"
340 fn main() {
341 ^[CURSOR_POSITION]
342 // comment 1
343 one();
344 two();
345 // comment 4
346 three();
347 four();
348 // comment 3
349 five();
350 six();
351 seven();
352 eight();
353 // comment 2
354 nine();
355 }
356 "}
357 .to_string(),
358 edit_history: indoc! {"
359 --- a/project/src/main.rs
360 +++ b/project/src/main.rs
361 @@ -2,8 +2,10 @@
362 // comment 1
363 one();
364 two();
365 + // comment 4
366 three();
367 four();
368 + // comment 3
369 five();
370 six();
371 seven();
372 "}
373 .to_string(),
374 expected_patches: Vec::new()
375 }
376 );
377 }
378
// Installs the globals the test relies on: a test settings store, test
// logging, a client backed by a fake HTTP client that answers 404, the
// language model subsystem, a user store, and the global
// `EditPredictionStore`.
379 fn init_test(cx: &mut TestAppContext) {
380 cx.update(|cx| {
381 let settings_store = SettingsStore::test(cx);
382 cx.set_global(settings_store);
383 zlog::init_test();
384 let http_client = FakeHttpClient::with_404_response();
385 let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
386 language_model::init(client.clone(), cx);
387 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
388 EditPredictionStore::global(&client, &user_store, cx);
389 })
390 }
391}