1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use feature_flags::FeatureFlagAppExt as _;
9use gpui::{App, Entity, Task};
10use language::{Buffer, ToPoint as _};
11use project::{Project, WorktreeId};
12use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
13use text::BufferSnapshot as TextBufferSnapshot;
14
/// Fallback sampling rate for example capture, expressed as captures per 10 000 edit
/// predictions (10 / 10 000 = 0.1%). Used when the `edit_predictions.example_capture_rate`
/// setting is not set.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
16
17pub fn capture_example(
18 project: Entity<Project>,
19 buffer: Entity<Buffer>,
20 cursor_anchor: language::Anchor,
21 mut events: Vec<StoredEvent>,
22 cx: &mut App,
23) -> Option<Task<Result<ExampleSpec>>> {
24 let snapshot = buffer.read(cx).snapshot();
25 let file = snapshot.file()?;
26 let worktree_id = file.worktree_id(cx);
27 let repository = project.read(cx).active_repository(cx)?;
28 let repository_snapshot = repository.read(cx).snapshot();
29 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
30 let root_name = worktree.read(cx).root_name_str().to_owned();
31 let cursor_path: Arc<Path> = file.path().as_std_path().into();
32 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
33 return None;
34 }
35
36 let repository_url = repository_snapshot
37 .remote_origin_url
38 .clone()
39 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
40 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
41
42 let git_store = project.read(cx).git_store().clone();
43
44 Some(cx.spawn(async move |mut cx| {
45 let snapshots_by_path =
46 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
47
48 events.retain(|stored_event| {
49 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
50 let relative_path = strip_root_name(path, &root_name);
51 snapshots_by_path.contains_key(relative_path)
52 });
53
54 let line_comment_prefix = snapshot
55 .language()
56 .and_then(|lang| lang.config().line_comments.first())
57 .map(|s| s.to_string())
58 .unwrap_or_default();
59 let (cursor_excerpt, cursor_offset) = cx
60 .background_executor()
61 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
62 .await;
63 let uncommitted_diff = cx
64 .background_executor()
65 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
66 .await;
67
68 let mut edit_history = String::new();
69 for stored_event in &events {
70 write_event_with_relative_paths(&mut edit_history, &stored_event.event, &root_name);
71 if !edit_history.ends_with('\n') {
72 edit_history.push('\n');
73 }
74 }
75
76 let mut spec = ExampleSpec {
77 name: generate_timestamp_name(),
78 repository_url,
79 revision,
80 tags: Vec::new(),
81 reasoning: None,
82 uncommitted_diff,
83 cursor_path,
84 cursor_position: String::new(),
85 edit_history,
86 expected_patches: Vec::new(),
87 };
88 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
89 Ok(spec)
90 }))
91}
92
/// Removes a leading `root_name` component from `path`, if there is one.
///
/// Matching is component-wise (as with [`Path::strip_prefix`]). So `"project"` strips
/// `project/src/main.rs` but not `projects/src/main.rs`. Paths that do not begin with
/// `root_name`, including absolute paths, are returned unchanged.
fn strip_root_name<'a>(path: &'a Path, root_name: &str) -> &'a Path {
    match path.strip_prefix(root_name) {
        Ok(relative) => relative,
        Err(_) => path,
    }
}
96
97fn write_event_with_relative_paths(
98 output: &mut String,
99 event: &zeta_prompt::Event,
100 root_name: &str,
101) {
102 fn write_relative_path(output: &mut String, path: &Path, root_name: &str) {
103 for component in strip_root_name(path, root_name).components() {
104 output.push('/');
105 write!(output, "{}", component.as_os_str().to_string_lossy()).ok();
106 }
107 }
108
109 let zeta_prompt::Event::BufferChange {
110 path,
111 old_path,
112 diff,
113 ..
114 } = event;
115
116 output.push_str("--- a");
117 write_relative_path(output, old_path.as_ref(), root_name);
118 output.push_str("\n+++ b");
119 write_relative_path(output, path.as_ref(), root_name);
120 output.push('\n');
121 output.push_str(diff);
122}
123
124fn compute_cursor_excerpt(
125 snapshot: &language::BufferSnapshot,
126 cursor_anchor: language::Anchor,
127) -> (String, usize) {
128 use text::ToOffset as _;
129
130 let cursor_point = cursor_anchor.to_point(snapshot);
131 let (_editable_range, context_range) =
132 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
133 let context_start_offset = context_range.start.to_offset(snapshot);
134 let cursor_offset = cursor_anchor.to_offset(snapshot);
135 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
136 let excerpt = snapshot.text_for_range(context_range).collect::<String>();
137 (excerpt, cursor_offset_in_excerpt)
138}
139
140async fn collect_snapshots(
141 project: &Entity<Project>,
142 git_store: &Entity<project::git_store::GitStore>,
143 worktree_id: WorktreeId,
144 events: &[StoredEvent],
145 cx: &mut gpui::AsyncApp,
146) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
147 let mut snapshots_by_path = HashMap::default();
148 for stored_event in events {
149 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
150 if let Some((project_path, relative_path)) = project.read_with(cx, |project, cx| {
151 let project_path = project
152 .find_project_path(path, cx)
153 .filter(|path| path.worktree_id == worktree_id)?;
154 let relative_path: Arc<Path> = project_path.path.as_std_path().into();
155 Some((project_path, relative_path))
156 }) {
157 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(relative_path) {
158 let buffer = project
159 .update(cx, |project, cx| {
160 project.open_buffer(project_path.clone(), cx)
161 })
162 .await?;
163 let diff = git_store
164 .update(cx, |git_store, cx| {
165 git_store.open_uncommitted_diff(buffer.clone(), cx)
166 })
167 .await?;
168 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx));
169 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
170 }
171 }
172 }
173 Ok(snapshots_by_path)
174}
175
176fn compute_uncommitted_diff(
177 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
178) -> String {
179 let mut uncommitted_diff = String::new();
180 for (relative_path, (before_text, diff_snapshot)) in snapshots_by_path {
181 if let Some(head_text) = &diff_snapshot.base_text_string() {
182 let file_diff = language::unified_diff(head_text, &before_text.text());
183 if !file_diff.is_empty() {
184 let path_str = relative_path.to_string_lossy();
185 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
186 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
187 uncommitted_diff.push_str(&file_diff);
188 if !uncommitted_diff.ends_with('\n') {
189 uncommitted_diff.push('\n');
190 }
191 }
192 }
193 }
194 uncommitted_diff
195}
196
197fn generate_timestamp_name() -> String {
198 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
199 match format {
200 Ok(format) => {
201 let now = time::OffsetDateTime::now_local()
202 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
203 now.format(&format)
204 .unwrap_or_else(|_| "unknown-time".to_string())
205 }
206 Err(_) => "unknown-time".to_string(),
207 }
208}
209
210pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
211 let capture_rate = language::language_settings::all_language_settings(None, cx)
212 .edit_predictions
213 .example_capture_rate
214 .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
215 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
216 && rand::random::<u16>() % 10_000 < capture_rate
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// - `uncommitted_diff` holds the on-disk changes relative to HEAD (comments 1 and 2).
    /// - `edit_history` holds only the in-editor edits (comments 3 and 4).
    /// - An edit to a file outside the project worktree is recorded as an event but is
    ///   excluded from the captured example.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // File contents at HEAD.
        let committed_contents = indoc! {"
            fn main() {
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                nine();
            }
        "};

        // File contents on disk when the buffer is opened: HEAD plus two comments,
        // which are expected to show up in `uncommitted_diff`.
        let disk_contents = indoc! {"
            fn main() {
                // comment 1
                one();
                two();
                three();
                four();
                five();
                six();
                seven();
                eight();
                // comment 2
                nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Create an external file outside the main project
        fs.insert_tree(
            "/external",
            json!({
                "external.rs": "fn external() {}\n",
            }),
        )
        .await;

        // Provides the `revision` and `repository_url` that the spec is expected to carry.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the edit prediction store before editing it, so the
        // edits below can be read back from the store's edit history.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // In-editor edits: insert comment 3 before `five();` (row 6), then comment 4
        // before `three();` (row 4). These should appear only in `edit_history`.
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, "    // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, "    // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        // Open and edit an external file (outside the main project's worktree)
        let external_buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/external/external.rs", cx)
            })
            .await
            .unwrap();
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&external_buffer, &project, cx)
        });
        cx.run_until_parked();
        external_buffer.update(cx, |buffer, cx| {
            let point = Point::new(0, 0);
            buffer.edit([(point..point, "// external edit\n")], None, cx);
        });
        cx.run_until_parked();

        // Verify the external edit was recorded in events
        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        assert!(
            matches!(
                events
                    .last()
                    .unwrap()
                    .event
                    .as_ref(),
                zeta_prompt::Event::BufferChange { path, .. } if path.as_ref() == "/external/external.rs"
            ),
            "external file edit should be in events"
        );

        // Capture with the cursor at the very start of the buffer (`Anchor::MIN`).
        let mut example = cx
            .update(|cx| {
                capture_example(project.clone(), buffer.clone(), Anchor::MIN, events, cx).unwrap()
            })
            .await
            .unwrap();
        // The real name is a wall-clock timestamp; pin it so the comparison is stable.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. on-disk text before the in-editor edits: comments 1 and 2 only.
                uncommitted_diff: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -1,4 +1,5 @@
                     fn main() {
                    +    // comment 1
                         one();
                         two();
                         three();
                    @@ -7,5 +8,6 @@
                         six();
                         seven();
                         eight();
                    +    // comment 2
                         nine();
                     }
                "}
                .to_string(),
                cursor_path: Path::new("src/main.rs").into(),
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                        // comment 1
                        one();
                        two();
                        // comment 4
                        three();
                        four();
                        // comment 3
                        five();
                        six();
                        seven();
                        eight();
                        // comment 2
                        nine();
                    }
                "}
                .to_string(),
                // Only the in-editor edits to src/main.rs; the external file's event is
                // filtered out because it lies outside the project's worktree.
                edit_history: indoc! {"
                    --- a/src/main.rs
                    +++ b/src/main.rs
                    @@ -2,8 +2,10 @@
                         // comment 1
                         one();
                         two();
                    +    // comment 4
                         three();
                         four();
                    +    // comment 3
                         five();
                         six();
                         seven();
                "}
                .to_string(),
                expected_patches: Vec::new()
            }
        );
    }

    /// Installs the test settings store, logging, a fake-HTTP client, language models,
    /// and the global `EditPredictionStore` that `test_capture_example` reads back.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}