1use client::test::FakeServer;
2use clock::{FakeSystemClock, ReplicaId};
3use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
4use cloud_llm_client::{PredictEditsBody, PredictEditsResponse};
5use gpui::TestAppContext;
6use http_client::FakeHttpClient;
7use indoc::indoc;
8use language::Point;
9use parking_lot::Mutex;
10use serde_json::json;
11use settings::SettingsStore;
12use util::{path, rel_path::rel_path};
13
14use crate::zeta1::MAX_EVENT_TOKENS;
15
16use super::*;
17
/// Text of the 0BSD license, embedded at compile time from `../license_examples/0bsd.txt`.
/// Tests write it as the `LICENSE` file of worktrees in which data collection is expected to
/// be allowed; worktrees without it are expected to disallow collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
19
20#[gpui::test]
21async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
22 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
23 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
24 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
25 });
26
27 let edit_preview = cx
28 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
29 .await;
30
31 let completion = EditPrediction {
32 edits,
33 edit_preview,
34 buffer: buffer.clone(),
35 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
36 id: EditPredictionId("the-id".into()),
37 inputs: EditPredictionInputs {
38 events: Default::default(),
39 included_files: Default::default(),
40 cursor_point: cloud_llm_client::predict_edits_v3::Point {
41 line: Line(0),
42 column: 0,
43 },
44 cursor_path: Path::new("").into(),
45 },
46 buffer_snapshotted_at: Instant::now(),
47 response_received_at: Instant::now(),
48 };
49
50 cx.update(|cx| {
51 assert_eq!(
52 from_completion_edits(
53 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
54 &buffer,
55 cx
56 ),
57 vec![(2..5, "REM".into()), (9..11, "".into())]
58 );
59
60 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
61 assert_eq!(
62 from_completion_edits(
63 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
64 &buffer,
65 cx
66 ),
67 vec![(2..2, "REM".into()), (6..8, "".into())]
68 );
69
70 buffer.update(cx, |buffer, cx| buffer.undo(cx));
71 assert_eq!(
72 from_completion_edits(
73 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
74 &buffer,
75 cx
76 ),
77 vec![(2..5, "REM".into()), (9..11, "".into())]
78 );
79
80 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
81 assert_eq!(
82 from_completion_edits(
83 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
84 &buffer,
85 cx
86 ),
87 vec![(3..3, "EM".into()), (7..9, "".into())]
88 );
89
90 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
91 assert_eq!(
92 from_completion_edits(
93 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
94 &buffer,
95 cx
96 ),
97 vec![(4..4, "M".into()), (8..10, "".into())]
98 );
99
100 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
101 assert_eq!(
102 from_completion_edits(
103 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
104 &buffer,
105 cx
106 ),
107 vec![(9..11, "".into())]
108 );
109
110 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
111 assert_eq!(
112 from_completion_edits(
113 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
114 &buffer,
115 cx
116 ),
117 vec![(4..4, "M".into()), (8..10, "".into())]
118 );
119
120 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
121 assert_eq!(
122 from_completion_edits(
123 &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
124 &buffer,
125 cx
126 ),
127 vec![(4..4, "M".into())]
128 );
129
130 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
131 assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
132 })
133}
134
135#[gpui::test]
136async fn test_clean_up_diff(cx: &mut TestAppContext) {
137 init_test(cx);
138
139 assert_eq!(
140 apply_edit_prediction(
141 indoc! {"
142 fn main() {
143 let word_1 = \"lorem\";
144 let range = word.len()..word.len();
145 }
146 "},
147 indoc! {"
148 <|editable_region_start|>
149 fn main() {
150 let word_1 = \"lorem\";
151 let range = word_1.len()..word_1.len();
152 }
153
154 <|editable_region_end|>
155 "},
156 cx,
157 )
158 .await,
159 indoc! {"
160 fn main() {
161 let word_1 = \"lorem\";
162 let range = word_1.len()..word_1.len();
163 }
164 "},
165 );
166
167 assert_eq!(
168 apply_edit_prediction(
169 indoc! {"
170 fn main() {
171 let story = \"the quick\"
172 }
173 "},
174 indoc! {"
175 <|editable_region_start|>
176 fn main() {
177 let story = \"the quick brown fox jumps over the lazy dog\";
178 }
179
180 <|editable_region_end|>
181 "},
182 cx,
183 )
184 .await,
185 indoc! {"
186 fn main() {
187 let story = \"the quick brown fox jumps over the lazy dog\";
188 }
189 "},
190 );
191}
192
193#[gpui::test]
194async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
195 init_test(cx);
196
197 let buffer_content = "lorem\n";
198 let completion_response = indoc! {"
199 ```animals.js
200 <|start_of_file|>
201 <|editable_region_start|>
202 lorem
203 ipsum
204 <|editable_region_end|>
205 ```"};
206
207 assert_eq!(
208 apply_edit_prediction(buffer_content, completion_response, cx).await,
209 "lorem\nipsum"
210 );
211}
212
213#[gpui::test]
214async fn test_can_collect_data(cx: &mut TestAppContext) {
215 init_test(cx);
216
217 let fs = project::FakeFs::new(cx.executor());
218 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
219 .await;
220
221 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
222 let buffer = project
223 .update(cx, |project, cx| {
224 project.open_local_buffer(path!("/project/src/main.rs"), cx)
225 })
226 .await
227 .unwrap();
228
229 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
230 zeta.update(cx, |zeta, _cx| {
231 zeta.data_collection_choice = DataCollectionChoice::Enabled
232 });
233
234 run_edit_prediction(&buffer, &project, &zeta, cx).await;
235 assert_eq!(
236 captured_request.lock().clone().unwrap().can_collect_data,
237 true
238 );
239
240 zeta.update(cx, |zeta, _cx| {
241 zeta.data_collection_choice = DataCollectionChoice::Disabled
242 });
243
244 run_edit_prediction(&buffer, &project, &zeta, cx).await;
245 assert_eq!(
246 captured_request.lock().clone().unwrap().can_collect_data,
247 false
248 );
249}
250
251#[gpui::test]
252async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
253 init_test(cx);
254
255 let fs = project::FakeFs::new(cx.executor());
256 let project = Project::test(fs.clone(), [], cx).await;
257
258 let buffer = cx.new(|_cx| {
259 Buffer::remote(
260 language::BufferId::new(1).unwrap(),
261 ReplicaId::new(1),
262 language::Capability::ReadWrite,
263 "fn main() {\n println!(\"Hello\");\n}",
264 )
265 });
266
267 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
268 zeta.update(cx, |zeta, _cx| {
269 zeta.data_collection_choice = DataCollectionChoice::Enabled
270 });
271
272 run_edit_prediction(&buffer, &project, &zeta, cx).await;
273 assert_eq!(
274 captured_request.lock().clone().unwrap().can_collect_data,
275 false
276 );
277}
278
279#[gpui::test]
280async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
281 init_test(cx);
282
283 let fs = project::FakeFs::new(cx.executor());
284 fs.insert_tree(
285 path!("/project"),
286 json!({
287 "LICENSE": BSD_0_TXT,
288 ".env": "SECRET_KEY=secret"
289 }),
290 )
291 .await;
292
293 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
294 let buffer = project
295 .update(cx, |project, cx| {
296 project.open_local_buffer("/project/.env", cx)
297 })
298 .await
299 .unwrap();
300
301 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
302 zeta.update(cx, |zeta, _cx| {
303 zeta.data_collection_choice = DataCollectionChoice::Enabled
304 });
305
306 run_edit_prediction(&buffer, &project, &zeta, cx).await;
307 assert_eq!(
308 captured_request.lock().clone().unwrap().can_collect_data,
309 false
310 );
311}
312
313#[gpui::test]
314async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
315 init_test(cx);
316
317 let fs = project::FakeFs::new(cx.executor());
318 let project = Project::test(fs.clone(), [], cx).await;
319 let buffer = cx.new(|cx| Buffer::local("", cx));
320
321 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
322 zeta.update(cx, |zeta, _cx| {
323 zeta.data_collection_choice = DataCollectionChoice::Enabled
324 });
325
326 run_edit_prediction(&buffer, &project, &zeta, cx).await;
327 assert_eq!(
328 captured_request.lock().clone().unwrap().can_collect_data,
329 false
330 );
331}
332
333#[gpui::test]
334async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
335 init_test(cx);
336
337 let fs = project::FakeFs::new(cx.executor());
338 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
339 .await;
340
341 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
342 let buffer = project
343 .update(cx, |project, cx| {
344 project.open_local_buffer("/project/main.rs", cx)
345 })
346 .await
347 .unwrap();
348
349 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
350 zeta.update(cx, |zeta, _cx| {
351 zeta.data_collection_choice = DataCollectionChoice::Enabled
352 });
353
354 run_edit_prediction(&buffer, &project, &zeta, cx).await;
355 assert_eq!(
356 captured_request.lock().clone().unwrap().can_collect_data,
357 false
358 );
359}
360
361#[gpui::test]
362async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
363 init_test(cx);
364
365 let fs = project::FakeFs::new(cx.executor());
366 fs.insert_tree(
367 path!("/open_source_worktree"),
368 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
369 )
370 .await;
371 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
372 .await;
373
374 let project = Project::test(
375 fs.clone(),
376 [
377 path!("/open_source_worktree").as_ref(),
378 path!("/closed_source_worktree").as_ref(),
379 ],
380 cx,
381 )
382 .await;
383 let buffer = project
384 .update(cx, |project, cx| {
385 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
386 })
387 .await
388 .unwrap();
389
390 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
391 zeta.update(cx, |zeta, _cx| {
392 zeta.data_collection_choice = DataCollectionChoice::Enabled
393 });
394
395 run_edit_prediction(&buffer, &project, &zeta, cx).await;
396 assert_eq!(
397 captured_request.lock().clone().unwrap().can_collect_data,
398 true
399 );
400
401 let closed_source_file = project
402 .update(cx, |project, cx| {
403 let worktree2 = project
404 .worktree_for_root_name("closed_source_worktree", cx)
405 .unwrap();
406 worktree2.update(cx, |worktree2, cx| {
407 worktree2.load_file(rel_path("main.rs"), cx)
408 })
409 })
410 .await
411 .unwrap()
412 .file;
413
414 buffer.update(cx, |buffer, cx| {
415 buffer.file_updated(closed_source_file, cx);
416 });
417
418 run_edit_prediction(&buffer, &project, &zeta, cx).await;
419 assert_eq!(
420 captured_request.lock().clone().unwrap().can_collect_data,
421 false
422 );
423}
424
425#[gpui::test]
426async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
427 init_test(cx);
428
429 let fs = project::FakeFs::new(cx.executor());
430 fs.insert_tree(
431 path!("/worktree1"),
432 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
433 )
434 .await;
435 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
436 .await;
437
438 let project = Project::test(
439 fs.clone(),
440 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
441 cx,
442 )
443 .await;
444 let buffer = project
445 .update(cx, |project, cx| {
446 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
447 })
448 .await
449 .unwrap();
450 let private_buffer = project
451 .update(cx, |project, cx| {
452 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
453 })
454 .await
455 .unwrap();
456
457 let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
458 zeta.update(cx, |zeta, _cx| {
459 zeta.data_collection_choice = DataCollectionChoice::Enabled
460 });
461
462 run_edit_prediction(&buffer, &project, &zeta, cx).await;
463 assert_eq!(
464 captured_request.lock().clone().unwrap().can_collect_data,
465 true
466 );
467
468 // this has a side effect of registering the buffer to watch for edits
469 run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
470 assert_eq!(
471 captured_request.lock().clone().unwrap().can_collect_data,
472 false
473 );
474
475 private_buffer.update(cx, |private_buffer, cx| {
476 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
477 });
478
479 run_edit_prediction(&buffer, &project, &zeta, cx).await;
480 assert_eq!(
481 captured_request.lock().clone().unwrap().can_collect_data,
482 false
483 );
484
485 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
486 // included
487 buffer.update(cx, |buffer, cx| {
488 buffer.edit(
489 [(
490 0..0,
491 " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
492 )],
493 None,
494 cx,
495 );
496 });
497
498 run_edit_prediction(&buffer, &project, &zeta, cx).await;
499 assert_eq!(
500 captured_request.lock().clone().unwrap().can_collect_data,
501 true
502 );
503}
504
505fn init_test(cx: &mut TestAppContext) {
506 cx.update(|cx| {
507 let settings_store = SettingsStore::test(cx);
508 cx.set_global(settings_store);
509 });
510}
511
512async fn apply_edit_prediction(
513 buffer_content: &str,
514 completion_response: &str,
515 cx: &mut TestAppContext,
516) -> String {
517 let fs = project::FakeFs::new(cx.executor());
518 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
519 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
520 let (zeta, _, response) = make_test_zeta(&project, cx).await;
521 *response.lock() = completion_response.to_string();
522 let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
523 buffer.update(cx, |buffer, cx| {
524 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
525 });
526 buffer.read_with(cx, |buffer, _| buffer.text())
527}
528
529async fn run_edit_prediction(
530 buffer: &Entity<Buffer>,
531 project: &Entity<Project>,
532 zeta: &Entity<Zeta>,
533 cx: &mut TestAppContext,
534) -> EditPrediction {
535 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
536 zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
537 cx.background_executor.run_until_parked();
538 let prediction_task = zeta.update(cx, |zeta, cx| {
539 zeta.request_prediction(&project, buffer, cursor, Default::default(), cx)
540 });
541 prediction_task.await.unwrap().unwrap().prediction.unwrap()
542}
543
544async fn make_test_zeta(
545 project: &Entity<Project>,
546 cx: &mut TestAppContext,
547) -> (
548 Entity<Zeta>,
549 Arc<Mutex<Option<PredictEditsBody>>>,
550 Arc<Mutex<String>>,
551) {
552 let default_response = indoc! {"
553 ```main.rs
554 <|start_of_file|>
555 <|editable_region_start|>
556 hello world
557 <|editable_region_end|>
558 ```"
559 };
560 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
561 let completion_response: Arc<Mutex<String>> =
562 Arc::new(Mutex::new(default_response.to_string()));
563 let http_client = FakeHttpClient::create({
564 let captured_request = captured_request.clone();
565 let completion_response = completion_response.clone();
566 let mut next_request_id = 0;
567 move |req| {
568 let captured_request = captured_request.clone();
569 let completion_response = completion_response.clone();
570 async move {
571 match (req.method(), req.uri().path()) {
572 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
573 .status(200)
574 .body(
575 serde_json::to_string(&CreateLlmTokenResponse {
576 token: LlmToken("the-llm-token".to_string()),
577 })
578 .unwrap()
579 .into(),
580 )
581 .unwrap()),
582 (&Method::POST, "/predict_edits/v2") => {
583 let mut request_body = String::new();
584 req.into_body().read_to_string(&mut request_body).await?;
585 *captured_request.lock() =
586 Some(serde_json::from_str(&request_body).unwrap());
587 next_request_id += 1;
588 Ok(http_client::Response::builder()
589 .status(200)
590 .body(
591 serde_json::to_string(&PredictEditsResponse {
592 request_id: format!("request-{next_request_id}"),
593 output_excerpt: completion_response.lock().clone(),
594 })
595 .unwrap()
596 .into(),
597 )
598 .unwrap())
599 }
600 _ => Ok(http_client::Response::builder()
601 .status(404)
602 .body("Not Found".into())
603 .unwrap()),
604 }
605 }
606 }
607 });
608
609 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
610 cx.update(|cx| {
611 RefreshLlmTokenListener::register(client.clone(), cx);
612 });
613 let _server = FakeServer::for_client(42, &client, cx).await;
614
615 let zeta = cx.new(|cx| {
616 let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
617 zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1);
618
619 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
620 for worktree in worktrees {
621 let worktree_id = worktree.read(cx).id();
622 zeta.get_or_init_zeta_project(project, cx)
623 .license_detection_watchers
624 .entry(worktree_id)
625 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
626 }
627
628 zeta
629 });
630
631 (zeta, captured_request, completion_response)
632}
633
634fn to_completion_edits(
635 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
636 buffer: &Entity<Buffer>,
637 cx: &App,
638) -> Vec<(Range<Anchor>, Arc<str>)> {
639 let buffer = buffer.read(cx);
640 iterator
641 .into_iter()
642 .map(|(range, text)| {
643 (
644 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
645 text,
646 )
647 })
648 .collect()
649}
650
651fn from_completion_edits(
652 editor_edits: &[(Range<Anchor>, Arc<str>)],
653 buffer: &Entity<Buffer>,
654 cx: &App,
655) -> Vec<(Range<usize>, Arc<str>)> {
656 let buffer = buffer.read(cx);
657 editor_edits
658 .iter()
659 .map(|(range, text)| {
660 (
661 range.start.to_offset(buffer)..range.end.to_offset(buffer),
662 text.clone(),
663 )
664 })
665 .collect()
666}
667
/// Initializes `zlog` test logging. `#[ctor::ctor]` runs this when the test binary is loaded,
/// so logging is configured before any test executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}