1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, EditPredictionStore, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use feature_flags::FeatureFlagAppExt as _;
9use gpui::{App, Entity, Task};
10use language::{Buffer, ToPoint as _};
11use project::{Project, WorktreeId};
12use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
13use text::BufferSnapshot as TextBufferSnapshot;
14
/// Fallback for the `edit_predictions.example_capture_rate` setting: the number of edit
/// predictions, out of every 10,000, that are sampled for example capture when the setting
/// is not configured.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
16
17pub fn capture_example(
18 project: Entity<Project>,
19 buffer: Entity<Buffer>,
20 cursor_anchor: language::Anchor,
21 cx: &mut App,
22) -> Option<Task<Result<ExampleSpec>>> {
23 let ep_store = EditPredictionStore::try_global(cx)?;
24 let snapshot = buffer.read(cx).snapshot();
25 let file = snapshot.file()?;
26 let worktree_id = file.worktree_id(cx);
27 let repository = project.read(cx).active_repository(cx)?;
28 let repository_snapshot = repository.read(cx).snapshot();
29 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
30 let cursor_path = worktree.read(cx).root_name().join(file.path());
31 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
32 return None;
33 }
34
35 let repository_url = repository_snapshot
36 .remote_origin_url
37 .clone()
38 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
39 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
40
41 let mut events = ep_store.update(cx, |store, cx| {
42 store.edit_history_for_project_with_pause_split_last_event(&project, cx)
43 });
44
45 let git_store = project.read(cx).git_store().clone();
46
47 Some(cx.spawn(async move |mut cx| {
48 let snapshots_by_path =
49 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
50
51 events.retain(|stored_event| {
52 match stored_event.event.as_ref() {
53 zeta_prompt::Event::BufferChange { path, .. } => {
54 if !snapshots_by_path.contains_key(path) {
55 return false;
56 }
57 }
58 }
59 true
60 });
61
62 let line_comment_prefix = snapshot
63 .language()
64 .and_then(|lang| lang.config().line_comments.first())
65 .map(|s| s.to_string())
66 .unwrap_or_default();
67 let (cursor_excerpt, cursor_offset) = cx
68 .background_executor()
69 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
70 .await;
71 let uncommitted_diff = cx
72 .background_executor()
73 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
74 .await;
75
76 let mut edit_history = String::new();
77 for stored_event in &events {
78 zeta_prompt::write_event(&mut edit_history, &stored_event.event);
79 if !edit_history.ends_with('\n') {
80 edit_history.push('\n');
81 }
82 }
83
84 let mut spec = ExampleSpec {
85 name: generate_timestamp_name(),
86 repository_url,
87 revision,
88 tags: Vec::new(),
89 reasoning: None,
90 uncommitted_diff,
91 cursor_path: cursor_path.as_std_path().into(),
92 cursor_position: String::new(),
93 edit_history,
94 expected_patches: Vec::new(),
95 };
96 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
97 Ok(spec)
98 }))
99}
100
101fn compute_cursor_excerpt(
102 snapshot: &language::BufferSnapshot,
103 cursor_anchor: language::Anchor,
104) -> (String, usize) {
105 use text::ToOffset as _;
106
107 let cursor_point = cursor_anchor.to_point(snapshot);
108 let (_editable_range, context_range) =
109 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
110 let context_start_offset = context_range.start.to_offset(snapshot);
111 let cursor_offset = cursor_anchor.to_offset(snapshot);
112 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
113 let excerpt = snapshot.text_for_range(context_range).collect::<String>();
114 (excerpt, cursor_offset_in_excerpt)
115}
116
117async fn collect_snapshots(
118 project: &Entity<Project>,
119 git_store: &Entity<project::git_store::GitStore>,
120 worktree_id: WorktreeId,
121 events: &[StoredEvent],
122 cx: &mut gpui::AsyncApp,
123) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
124 let mut snapshots_by_path = HashMap::default();
125 let root_name = project.read_with(cx, |project, cx| {
126 project
127 .worktree_for_id(worktree_id, cx)
128 .unwrap()
129 .read(cx)
130 .root_name()
131 .to_owned()
132 })?;
133 for stored_event in events {
134 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
135 if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
136 let project_path = project
137 .find_project_path(path, cx)
138 .filter(|path| path.worktree_id == worktree_id)?;
139 let full_path = root_name.join(&project_path.path).as_std_path().into();
140 Some((project_path, full_path))
141 })? {
142 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
143 let buffer = project
144 .update(cx, |project, cx| {
145 project.open_buffer(project_path.clone(), cx)
146 })?
147 .await?;
148 let diff = git_store
149 .update(cx, |git_store, cx| {
150 git_store.open_uncommitted_diff(buffer.clone(), cx)
151 })?
152 .await?;
153 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
154 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
155 }
156 }
157 }
158 Ok(snapshots_by_path)
159}
160
161fn compute_uncommitted_diff(
162 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
163) -> String {
164 let mut uncommitted_diff = String::new();
165 for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
166 if let Some(head_text) = &diff_snapshot.base_text_string() {
167 let file_diff = language::unified_diff(head_text, &before_text.text());
168 if !file_diff.is_empty() {
169 let path_str = full_path.to_string_lossy();
170 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
171 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
172 uncommitted_diff.push_str(&file_diff);
173 if !uncommitted_diff.ends_with('\n') {
174 uncommitted_diff.push('\n');
175 }
176 }
177 }
178 }
179 uncommitted_diff
180}
181
182fn generate_timestamp_name() -> String {
183 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
184 match format {
185 Ok(format) => {
186 let now = time::OffsetDateTime::now_local()
187 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
188 now.format(&format)
189 .unwrap_or_else(|_| "unknown-time".to_string())
190 }
191 Err(_) => "unknown-time".to_string(),
192 }
193}
194
195pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
196 let capture_rate = language::language_settings::all_language_settings(None, cx)
197 .edit_predictions
198 .example_capture_rate
199 .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
200 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
201 && rand::random::<u16>() % 10_000 < capture_rate
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`.
    ///
    /// HEAD holds `committed_contents`. The file on disk adds "comment 1" and "comment 2".
    /// The buffer is then edited to add "comment 3" and "comment 4". The expected spec
    /// reports the on-disk changes in `uncommitted_diff` and the buffer edits in
    /// `edit_history`.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents of `src/main.rs` at the HEAD commit.
        let committed_contents = indoc! {"
 fn main() {
 one();
 two();
 three();
 four();
 five();
 six();
 seven();
 eight();
 nine();
 }
 "};

        // Contents on disk: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
 fn main() {
 // comment 1
 one();
 two();
 three();
 four();
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // Fake repository: `committed_contents` at HEAD `abc123def456`, with an `origin` remote.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the global edit prediction store so that its subsequent
        // edits show up in the captured edit history.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        // Edit the buffer in memory (not on disk): insert "comment 3", then "comment 4".
        buffer.update(cx, |buffer, cx| {
            let point = Point::new(6, 0);
            buffer.edit([(point..point, " // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, " // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
 fn main() {
 // comment 1
 one();
 two();
 // comment 4
 three();
 four();
 // comment 3
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "}
            );
        });
        cx.run_until_parked();

        // Capture with the cursor at the very start of the buffer. The timestamp-based name
        // is overwritten so the whole spec can be compared exactly.
        let mut example = cx
            .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
            .await
            .unwrap();
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // HEAD vs. the on-disk contents: only comments 1 and 2.
                uncommitted_diff: indoc! {"
 --- a/project/src/main.rs
 +++ b/project/src/main.rs
 @@ -1,4 +1,5 @@
 fn main() {
 + // comment 1
 one();
 two();
 three();
 @@ -7,5 +8,6 @@
 six();
 seven();
 eight();
 + // comment 2
 nine();
 }
 "}
                .to_string(),
                // Paths are prefixed with the worktree root name ("project").
                cursor_path: Path::new("project/src/main.rs").into(),
                // Current buffer text with the cursor marker for `Anchor::MIN`.
                cursor_position: indoc! {"
 fn main() {
 ^[CURSOR_POSITION]
 // comment 1
 one();
 two();
 // comment 4
 three();
 four();
 // comment 3
 five();
 six();
 seven();
 eight();
 // comment 2
 nine();
 }
 "}
                .to_string(),
                // The in-buffer edits: comments 4 and 3.
                edit_history: indoc! {"
 --- a/project/src/main.rs
 +++ b/project/src/main.rs
 @@ -2,8 +2,10 @@
 // comment 1
 one();
 two();
 + // comment 4
 three();
 four();
 + // comment 3
 five();
 six();
 seven();
 "}
                .to_string(),
                expected_patches: Vec::new()
            }
        );
    }

    /// Installs the globals `capture_example` relies on in tests:
    /// - a test `SettingsStore`;
    /// - test logging;
    /// - a `Client` whose fake HTTP client responds with 404, used to initialize
    ///   `language_model` and a `UserStore`;
    /// - a global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}