1use crate::{
2 EditPredictionExampleCaptureFeatureFlag, StoredEvent,
3 cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
4};
5use anyhow::Result;
6use buffer_diff::BufferDiffSnapshot;
7use collections::HashMap;
8use feature_flags::FeatureFlagAppExt as _;
9use gpui::{App, Entity, Task};
10use language::{Buffer, ToPoint as _};
11use project::{Project, WorktreeId};
12use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
13use text::BufferSnapshot as TextBufferSnapshot;
14
/// Fallback capture rate, expressed as captures per 10,000 predictions, used when the
/// `example_capture_rate` setting is unset.
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
16
17pub fn capture_example(
18 project: Entity<Project>,
19 buffer: Entity<Buffer>,
20 cursor_anchor: language::Anchor,
21 mut events: Vec<StoredEvent>,
22 cx: &mut App,
23) -> Option<Task<Result<ExampleSpec>>> {
24 let snapshot = buffer.read(cx).snapshot();
25 let file = snapshot.file()?;
26 let worktree_id = file.worktree_id(cx);
27 let repository = project.read(cx).active_repository(cx)?;
28 let repository_snapshot = repository.read(cx).snapshot();
29 let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
30 let cursor_path = worktree.read(cx).root_name().join(file.path());
31 if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
32 return None;
33 }
34
35 let repository_url = repository_snapshot
36 .remote_origin_url
37 .clone()
38 .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
39 let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
40
41 let git_store = project.read(cx).git_store().clone();
42
43 Some(cx.spawn(async move |mut cx| {
44 let snapshots_by_path =
45 collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
46
47 events.retain(|stored_event| {
48 match stored_event.event.as_ref() {
49 zeta_prompt::Event::BufferChange { path, .. } => {
50 if !snapshots_by_path.contains_key(path) {
51 return false;
52 }
53 }
54 }
55 true
56 });
57
58 let line_comment_prefix = snapshot
59 .language()
60 .and_then(|lang| lang.config().line_comments.first())
61 .map(|s| s.to_string())
62 .unwrap_or_default();
63 let (cursor_excerpt, cursor_offset) = cx
64 .background_executor()
65 .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
66 .await;
67 let uncommitted_diff = cx
68 .background_executor()
69 .spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
70 .await;
71
72 let mut edit_history = String::new();
73 for stored_event in &events {
74 zeta_prompt::write_event(&mut edit_history, &stored_event.event);
75 if !edit_history.ends_with('\n') {
76 edit_history.push('\n');
77 }
78 }
79
80 let mut spec = ExampleSpec {
81 name: generate_timestamp_name(),
82 repository_url,
83 revision,
84 tags: Vec::new(),
85 reasoning: None,
86 uncommitted_diff,
87 cursor_path: cursor_path.as_std_path().into(),
88 cursor_position: String::new(),
89 edit_history,
90 expected_patches: Vec::new(),
91 };
92 spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
93 Ok(spec)
94 }))
95}
96
97fn compute_cursor_excerpt(
98 snapshot: &language::BufferSnapshot,
99 cursor_anchor: language::Anchor,
100) -> (String, usize) {
101 use text::ToOffset as _;
102
103 let cursor_point = cursor_anchor.to_point(snapshot);
104 let (_editable_range, context_range) =
105 editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
106 let context_start_offset = context_range.start.to_offset(snapshot);
107 let cursor_offset = cursor_anchor.to_offset(snapshot);
108 let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
109 let excerpt = snapshot.text_for_range(context_range).collect::<String>();
110 (excerpt, cursor_offset_in_excerpt)
111}
112
113async fn collect_snapshots(
114 project: &Entity<Project>,
115 git_store: &Entity<project::git_store::GitStore>,
116 worktree_id: WorktreeId,
117 events: &[StoredEvent],
118 cx: &mut gpui::AsyncApp,
119) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
120 let mut snapshots_by_path = HashMap::default();
121 let root_name = project.read_with(cx, |project, cx| {
122 project
123 .worktree_for_id(worktree_id, cx)
124 .unwrap()
125 .read(cx)
126 .root_name()
127 .to_owned()
128 })?;
129 for stored_event in events {
130 let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
131 if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
132 let project_path = project
133 .find_project_path(path, cx)
134 .filter(|path| path.worktree_id == worktree_id)?;
135 let full_path = root_name.join(&project_path.path).as_std_path().into();
136 Some((project_path, full_path))
137 })? {
138 if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
139 let buffer = project
140 .update(cx, |project, cx| {
141 project.open_buffer(project_path.clone(), cx)
142 })?
143 .await?;
144 let diff = git_store
145 .update(cx, |git_store, cx| {
146 git_store.open_uncommitted_diff(buffer.clone(), cx)
147 })?
148 .await?;
149 let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
150 entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
151 }
152 }
153 }
154 Ok(snapshots_by_path)
155}
156
157fn compute_uncommitted_diff(
158 snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
159) -> String {
160 let mut uncommitted_diff = String::new();
161 for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
162 if let Some(head_text) = &diff_snapshot.base_text_string() {
163 let file_diff = language::unified_diff(head_text, &before_text.text());
164 if !file_diff.is_empty() {
165 let path_str = full_path.to_string_lossy();
166 writeln!(uncommitted_diff, "--- a/{path_str}").ok();
167 writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
168 uncommitted_diff.push_str(&file_diff);
169 if !uncommitted_diff.ends_with('\n') {
170 uncommitted_diff.push('\n');
171 }
172 }
173 }
174 }
175 uncommitted_diff
176}
177
178fn generate_timestamp_name() -> String {
179 let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
180 match format {
181 Ok(format) => {
182 let now = time::OffsetDateTime::now_local()
183 .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
184 now.format(&format)
185 .unwrap_or_else(|_| "unknown-time".to_string())
186 }
187 Err(_) => "unknown-time".to_string(),
188 }
189}
190
191pub(crate) fn should_sample_edit_prediction_example_capture(cx: &App) -> bool {
192 let capture_rate = language::language_settings::all_language_settings(None, cx)
193 .edit_predictions
194 .example_capture_rate
195 .unwrap_or(DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS);
196 cx.has_flag::<EditPredictionExampleCaptureFeatureFlag>()
197 && rand::random::<u16>() % 10_000 < capture_rate
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EditPredictionStore;
    use client::{Client, UserStore};
    use clock::FakeSystemClock;
    use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
    use indoc::indoc;
    use language::{Anchor, Point};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;

    /// End-to-end check of `capture_example`: sets up a fake repository whose on-disk file
    /// differs from HEAD, makes two further in-buffer edits, and compares the whole captured
    /// `ExampleSpec` against the expected value.
    ///
    /// NOTE(review): leading whitespace inside the fixture literals below looks collapsed in
    /// this copy of the file (the inserted text keeps one leading space while the expected
    /// lines carry no relative indentation) — confirm against the upstream fixture before
    /// relying on the expected values.
    #[gpui::test]
    async fn test_capture_example(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());

        // Contents registered as the repository HEAD version of `src/main.rs`.
        let committed_contents = indoc! {"
            fn main() {
            one();
            two();
            three();
            four();
            five();
            six();
            seven();
            eight();
            nine();
            }
        "};

        // Contents written to the fake file system: HEAD plus two uncommitted comment lines.
        let disk_contents = indoc! {"
            fn main() {
            // comment 1
            one();
            two();
            three();
            four();
            five();
            six();
            seven();
            eight();
            // comment 2
            nine();
            }
        "};

        fs.insert_tree(
            "/project",
            json!({
                ".git": {},
                "src": {
                    "main.rs": disk_contents,
                }
            }),
        )
        .await;

        // `capture_example` needs a head commit and a remote URL, so provide both.
        fs.set_head_for_repo(
            Path::new("/project/.git"),
            &[("src/main.rs", committed_contents.to_string())],
            "abc123def456",
        );
        fs.set_remote_for_repo(
            Path::new("/project/.git"),
            "origin",
            "https://github.com/test/repo.git",
        );

        let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer("/project/src/main.rs", cx)
            })
            .await
            .unwrap();

        // Register the buffer with the store before editing it; the edit history is read
        // back from the store further down.
        let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
        ep_store.update(cx, |ep_store, cx| {
            ep_store.register_buffer(&buffer, &project, cx)
        });
        cx.run_until_parked();

        buffer.update(cx, |buffer, cx| {
            // The later row is edited first, so the second insertion's row is not shifted.
            let point = Point::new(6, 0);
            buffer.edit([(point..point, " // comment 3\n")], None, cx);
            let point = Point::new(4, 0);
            buffer.edit([(point..point, " // comment 4\n")], None, cx);

            pretty_assertions::assert_eq!(
                buffer.text(),
                indoc! {"
                    fn main() {
                    // comment 1
                    one();
                    two();
                    // comment 4
                    three();
                    four();
                    // comment 3
                    five();
                    six();
                    seven();
                    eight();
                    // comment 2
                    nine();
                    }
                "}
            );
        });
        cx.run_until_parked();

        let events = ep_store.update(cx, |store, cx| {
            store.edit_history_for_project_with_pause_split_last_event(&project, cx)
        });
        // The cursor is placed at `Anchor::MIN`.
        let mut example = cx
            .update(|cx| {
                capture_example(project.clone(), buffer.clone(), Anchor::MIN, events, cx).unwrap()
            })
            .await
            .unwrap();
        // The captured name is derived from the current time, so pin it for the comparison.
        example.name = "test".to_string();

        pretty_assertions::assert_eq!(
            example,
            ExampleSpec {
                name: "test".to_string(),
                repository_url: "https://github.com/test/repo.git".to_string(),
                revision: "abc123def456".to_string(),
                tags: Vec::new(),
                reasoning: None,
                // Only the on-disk changes relative to HEAD (comments 1 and 2) appear here.
                uncommitted_diff: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -1,4 +1,5 @@
                    fn main() {
                    + // comment 1
                    one();
                    two();
                    three();
                    @@ -7,5 +8,6 @@
                    six();
                    seven();
                    eight();
                    + // comment 2
                    nine();
                    }
                "}
                .to_string(),
                cursor_path: Path::new("project/src/main.rs").into(),
                // The excerpt is the edited buffer text with a cursor marker line inserted.
                cursor_position: indoc! {"
                    fn main() {
                    ^[CURSOR_POSITION]
                    // comment 1
                    one();
                    two();
                    // comment 4
                    three();
                    four();
                    // comment 3
                    five();
                    six();
                    seven();
                    eight();
                    // comment 2
                    nine();
                    }
                "}
                .to_string(),
                // The in-buffer edits (comments 3 and 4) are reported as edit history.
                edit_history: indoc! {"
                    --- a/project/src/main.rs
                    +++ b/project/src/main.rs
                    @@ -2,8 +2,10 @@
                    // comment 1
                    one();
                    two();
                    + // comment 4
                    three();
                    four();
                    + // comment 3
                    five();
                    six();
                    seven();
                "}
                .to_string(),
                expected_patches: Vec::new()
            }
        );
    }

    /// Installs the globals the test relies on: a test settings store, test logging, a
    /// client backed by an HTTP client that answers 404, the language-model subsystem, and
    /// the global `EditPredictionStore`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
            let http_client = FakeHttpClient::with_404_response();
            let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            language_model::init(client.clone(), cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            EditPredictionStore::global(&client, &user_store, cx);
        })
    }
}