1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
10};
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::App;
16use gpui::{
17 Entity, TestAppContext,
18 http_client::{FakeHttpClient, Response},
19};
20use indoc::indoc;
21use language::{Buffer, Point};
22use lsp::LanguageServerId;
23use parking_lot::Mutex;
24use pretty_assertions::{assert_eq, assert_matches};
25use project::{FakeFs, Project};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, sync::Arc, time::Duration};
29use util::{path, rel_path::rel_path};
30use uuid::Uuid;
31use zeta_prompt::ZetaPromptInput;
32
33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
34
// Exercises `prediction_at` state reporting: a prediction that edits the
// queried buffer surfaces as `Local`; a prediction that targets a *different*
// file (triggered here by publishing a diagnostic in that file) surfaces as a
// `Jump` from the active buffer, and as `Local` once the target buffer itself
// is queried.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor on row 1 ("How"), right after the word.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The edit targets the buffer we asked from, so it is reported as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard it so the next prediction can take its place.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // Publishing the diagnostic triggers a new prediction request, this time
    // answered with an edit to 2.txt.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                Hola!
                -Como
                +Como estas?
                Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer 1's perspective the prediction lives in another file, so it
    // is reported as a `Jump` pointing at root/2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // The same prediction is `Local` when queried from the buffer it edits.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
170
// Happy path for `request_prediction`: one request, one diff response, and the
// resulting prediction's edit range/text are checked against the buffer
// snapshot (insertion of " are you?" at row 1, column 3).
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on row 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff's single changed line is interpolated into one buffer edit:
    // an insertion at the cursor, leaving "How" untouched.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
237
// Verifies that edits made in a registered buffer are recorded as edit-history
// events and included in the prompt of a subsequent prediction request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registering makes the store observe the buffer's edits.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Offset 7 is the empty row 1; this inserts "How" there.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The manual edit above must show up in the request prompt as a diff.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
310
// Verifies that a pause longer than `LAST_CHANGE_GROUPING_TIME` between edit
// bursts splits the edit history into two separate events, even when both
// bursts touch the same line (which would otherwise be coalesced).
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 0: everything before the pause (the "How" insertion).
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}
    );

    // Event 1: both post-pause edits, coalesced into one diff.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -How
            +How are you?!
            Bye
        "}
    );
}
388
// Exercises line-distance-based coalescing of edit-history events: edits within
// 8 lines of an existing event's range (above or below) merge into that event,
// while a sufficiently distant edit starts a new event.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // `render_events` joins the two events' diffs with a "---" separator line.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17

            ---
            @@ -23,6 +23,7 @@
            Line 22
            Line 23
            Line 24
            +Far below
            Line 25
            Line 26
            Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
567
568fn render_events(events: &[StoredEvent]) -> String {
569 events
570 .iter()
571 .map(|e| {
572 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
573 diff.as_str()
574 })
575 .collect::<Vec<_>>()
576 .join("\n---\n")
577}
578
579fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
580 events
581 .iter()
582 .map(|e| {
583 let zeta_prompt::Event::BufferChange {
584 diff, predicted, ..
585 } = e.event.as_ref();
586 let prefix = if *predicted { "predicted" } else { "manual" };
587 format!("{}\n{}", prefix, diff)
588 })
589 .collect()
590}
591
// Exercises how the `predicted` flag on edit-history events interacts with
// coalescing: manual edits and accepted predictions each coalesce with their
// own kind, a manual edit never merges into a preceding predicted event while
// that event is the most recent one, and a later far-away edit "seals" the
// prior location so its mixed edits can merge (losing the `predicted` flag
// unless all constituents were predicted).
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
            line 1
            line 2
            line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
            line 2
            line 3
            line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    //
    // `report_changes_for_buffer(.., true, ..)` marks the pending buffer
    // changes as coming from an accepted prediction.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                LINE ZERO
                LINE ONE
                -line 2
                +LINE TWO
                line 3
                line 4
                line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                LINE ZERO
                LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                line 4
                line 5
                line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                LINE ZERO
                LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                line 4
                line 5
                line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                LINE ONE
                LINE TWO
                LINE THREE
                -line 4
                +LINE FOUR
                line 5
                line 6
                line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                line 5
                line 6
                line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                line 11
                line 12
                line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
850
// A model response containing no edits produces no prediction, and the request
// is automatically reported back as rejected with reason `Empty` and
// `was_shown: false`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff: no edits to apply.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
905
// If the user edits the buffer while a request is in flight such that the
// predicted edit is already present (interpolation leaves nothing to apply),
// the prediction is dropped and reported as rejected with reason
// `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is pending, the user types exactly what the model is
    // about to predict.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
965
/// Canned model response used by several tests below: turns "How" into
/// "How are you?" on row 1 of `root/foo.md`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
975
// When a newer prediction arrives that supersedes the current one, the new
// prediction takes its place and the old one is reported as rejected with
// reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
1057
// When a newer prediction is judged worse than the current one, the current
// prediction is kept and the new one is reported as rejected with reason
// `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
1151
// When two requests are in flight and the later one resolves first, the
// earlier request is cancelled: its late response does not replace the current
// prediction, and it is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // The first (stale) request now resolves; it must not win.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
1242
/// With two predictions already in flight, starting a third cancels the older
/// pending one (the second). The cancelled response is ignored when it finally
/// arrives, and both the cancellation and the eventual replacement of the
/// first prediction are reported to the server in one batched reject request.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Now let the cancelled (second) request respond; its result must be dropped.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // One batched report: second as Canceled, first as Replaced (by the third).
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1378
/// Rejection reports are debounced and batched: a quiet period flushes all
/// queued rejections in one request, hitting the batch-size limit flushes
/// immediately without waiting for the debounce, and a failed flush keeps its
/// items queued so they are retried together with later rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure: dropping the responder makes the fake HTTP request
    // fail, so "retry-1" should remain queued for the next flush.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1495
1496// Skipped until we start including diagnostics in prompt
1497// #[gpui::test]
1498// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1499// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1500// let fs = FakeFs::new(cx.executor());
1501// fs.insert_tree(
1502// "/root",
1503// json!({
1504// "foo.md": "Hello!\nBye"
1505// }),
1506// )
1507// .await;
1508// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1509
1510// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1511// let diagnostic = lsp::Diagnostic {
1512// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1513// severity: Some(lsp::DiagnosticSeverity::ERROR),
1514// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1515// ..Default::default()
1516// };
1517
1518// project.update(cx, |project, cx| {
1519// project.lsp_store().update(cx, |lsp_store, cx| {
1520// // Create some diagnostics
1521// lsp_store
1522// .update_diagnostics(
1523// LanguageServerId(0),
1524// lsp::PublishDiagnosticsParams {
1525// uri: path_to_buffer_uri.clone(),
1526// diagnostics: vec![diagnostic],
1527// version: None,
1528// },
1529// None,
1530// language::DiagnosticSourceKind::Pushed,
1531// &[],
1532// cx,
1533// )
1534// .unwrap();
1535// });
1536// });
1537
1538// let buffer = project
1539// .update(cx, |project, cx| {
1540// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1541// project.open_buffer(path, cx)
1542// })
1543// .await
1544// .unwrap();
1545
1546// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1547// let position = snapshot.anchor_before(language::Point::new(0, 0));
1548
1549// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1550// ep_store.request_prediction(&project, &buffer, position, cx)
1551// });
1552
1553// let (request, _respond_tx) = req_rx.next().await.unwrap();
1554
1555// assert_eq!(request.diagnostic_groups.len(), 1);
1556// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1557// .unwrap();
1558// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1559// assert_eq!(
1560// value,
1561// json!({
1562// "entries": [{
1563// "range": {
1564// "start": 8,
1565// "end": 10
1566// },
1567// "diagnostic": {
1568// "source": null,
1569// "code": null,
1570// "code_description": null,
1571// "severity": 1,
1572// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1573// "markdown": null,
1574// "group_id": 0,
1575// "is_primary": true,
1576// "is_disk_based": false,
1577// "is_unnecessary": false,
1578// "source_kind": "Pushed",
1579// "data": null,
1580// "underline": true
1581// }
1582// }],
1583// "primary_ix": 0
1584// })
1585// );
1586// }
1587
1588// Generate a model response that would apply the given diff to the active file.
1589fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1590 let excerpt =
1591 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1592 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1593
1594 PredictEditsV3Response {
1595 request_id: Uuid::new_v4().to_string(),
1596 output: new_excerpt,
1597 }
1598}
1599
1600fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1601 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1602}
1603
/// Receiving ends of the channels the fake HTTP client in
/// `init_test_with_fake_client` uses to hand intercepted requests to tests.
struct RequestChannels {
    // Each item pairs a decoded prediction request with a oneshot sender the
    // test uses to supply the response.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Same pattern for rejection batches; sending `()` completes the request,
    // dropping the sender simulates a request failure.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1611
/// Sets up an `EditPredictionStore` backed by a fake HTTP client that forwards
/// intercepted predict/reject requests over channels, letting tests inspect
/// each request and control when and how it is answered.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answer with a static fake token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction requests arrive zstd-compressed; decode
                        // them and block until the test supplies a response.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection batches are plain JSON. If the test drops
                        // the responder, `res_rx.await?` errors, which models
                        // a failed request.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1681
/// Full text of the 0BSD license; used to mark test worktrees as open source.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1683
/// A prediction's edits must interpolate across subsequent buffer edits:
/// typing text that matches the prediction consumes the corresponding edit,
/// deletions shift or restore edits, and a conflicting edit invalidates the
/// whole prediction (interpolate returns None).
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Build a prediction by hand; the prompt inputs are irrelevant for
    // interpolation, so they are left empty.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: edits interpolate to themselves.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the replacement into an insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted text shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the final predicted character consumes the first edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of the typed text brings the insertion edit back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion manually consumes that edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // A conflicting edit invalidates the prediction altogether.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1799
/// Applying a prediction cleans up the raw model output: the editable-region
/// markers and the extra trailing blank line are stripped, leaving only the
/// edited content.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Small in-place rename within a line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Completion that extends a string literal and adds the semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1857
1858#[gpui::test]
1859async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1860 init_test(cx);
1861
1862 let buffer_content = "lorem\n";
1863 let completion_response = indoc! {"
1864 ```animals.js
1865 <|start_of_file|>
1866 <|editable_region_start|>
1867 lorem
1868 ipsum
1869 <|editable_region_end|>
1870 ```"};
1871
1872 assert_eq!(
1873 apply_edit_prediction(buffer_content, completion_response, cx).await,
1874 "lorem\nipsum"
1875 );
1876}
1877
/// Newline normalization must not introduce spurious trailing newlines when
/// the buffer lacks one but the model output includes one.
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "hello".
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
1942
/// `can_collect_data` follows the user's data-collection choice for a buffer
/// in a worktree marked open source by a 0BSD LICENSE file.
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // The 0BSD LICENSE marks this worktree as open source.
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    // NOTE(review): src/main.rs is not in the tree above; this presumably
    // relies on open_local_buffer succeeding for a not-yet-saved path — confirm.
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Turning the choice off must flip the flag on the next request.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1980
/// Data collection is never enabled for a remote buffer, even when the local
/// user has opted in.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    // A buffer replicated from another collaborator (non-zero replica id).
    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
2008
2009#[gpui::test]
2010async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
2011 init_test(cx);
2012
2013 let fs = project::FakeFs::new(cx.executor());
2014 fs.insert_tree(
2015 path!("/project"),
2016 json!({
2017 "LICENSE": BSD_0_TXT,
2018 ".env": "SECRET_KEY=secret"
2019 }),
2020 )
2021 .await;
2022
2023 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2024 let buffer = project
2025 .update(cx, |project, cx| {
2026 project.open_local_buffer("/project/.env", cx)
2027 })
2028 .await
2029 .unwrap();
2030
2031 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2032 ep_store.update(cx, |ep_store, _cx| {
2033 ep_store.data_collection_choice = DataCollectionChoice::Enabled
2034 });
2035
2036 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2037 assert_eq!(
2038 captured_request.lock().clone().unwrap().can_collect_data,
2039 false
2040 );
2041}
2042
2043#[gpui::test]
2044async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
2045 init_test(cx);
2046
2047 let fs = project::FakeFs::new(cx.executor());
2048 let project = Project::test(fs.clone(), [], cx).await;
2049 let buffer = cx.new(|cx| Buffer::local("", cx));
2050
2051 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2052 ep_store.update(cx, |ep_store, _cx| {
2053 ep_store.data_collection_choice = DataCollectionChoice::Enabled
2054 });
2055
2056 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2057 assert_eq!(
2058 captured_request.lock().clone().unwrap().can_collect_data,
2059 false
2060 );
2061}
2062
2063#[gpui::test]
2064async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
2065 init_test(cx);
2066
2067 let fs = project::FakeFs::new(cx.executor());
2068 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
2069 .await;
2070
2071 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2072 let buffer = project
2073 .update(cx, |project, cx| {
2074 project.open_local_buffer("/project/main.rs", cx)
2075 })
2076 .await
2077 .unwrap();
2078
2079 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2080 ep_store.update(cx, |ep_store, _cx| {
2081 ep_store.data_collection_choice = DataCollectionChoice::Enabled
2082 });
2083
2084 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2085 assert_eq!(
2086 captured_request.lock().clone().unwrap().can_collect_data,
2087 false
2088 );
2089}
2090
/// Collectability is re-evaluated per request: a buffer moved from an
/// open-source worktree to a closed-source one stops being collectable.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While in the licensed worktree, collection is allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    // Simulate the buffer being moved: point it at the closed-source file.
    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
2154
/// A request from a collectable buffer becomes uncollectable if its event
/// history includes edits from an uncollectable (unlicensed) buffer; once
/// those events fall out of the token budget, collection is allowed again.
#[gpui::test]
async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/worktree1"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
        })
        .await
        .unwrap();
    // NOTE(review): the tree above contains "private.rs" but this opens
    // "file.rs". Any /worktree2 path is uncollectable (no license), so the
    // test still works — confirm whether the filename mismatch is intended.
    let private_buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // this has a side effect of registering the buffer to watch for edits
    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    private_buffer.update(cx, |private_buffer, cx| {
        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
    });

    // The uncollectable edit is now in the event history, tainting requests
    // even from the collectable buffer.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
    // included
    buffer.update(cx, |buffer, cx| {
        buffer.edit(
            [(
                0..0,
                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
            )],
            None,
            cx,
        );
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );
}
2234
2235fn init_test(cx: &mut TestAppContext) {
2236 cx.update(|cx| {
2237 let settings_store = SettingsStore::test(cx);
2238 cx.set_global(settings_store);
2239 });
2240}
2241
2242async fn apply_edit_prediction(
2243 buffer_content: &str,
2244 completion_response: &str,
2245 cx: &mut TestAppContext,
2246) -> String {
2247 let fs = project::FakeFs::new(cx.executor());
2248 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2249 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2250 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2251 *response.lock() = completion_response.to_string();
2252 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2253 buffer.update(cx, |buffer, cx| {
2254 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2255 });
2256 buffer.read_with(cx, |buffer, _| buffer.text())
2257}
2258
2259async fn run_edit_prediction(
2260 buffer: &Entity<Buffer>,
2261 project: &Entity<Project>,
2262 ep_store: &Entity<EditPredictionStore>,
2263 cx: &mut TestAppContext,
2264) -> EditPrediction {
2265 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2266 ep_store.update(cx, |ep_store, cx| {
2267 ep_store.register_buffer(buffer, &project, cx)
2268 });
2269 cx.background_executor.run_until_parked();
2270 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2271 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2272 });
2273 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2274}
2275
/// Builds an `EditPredictionStore` backed by a fake HTTP server for tests.
///
/// Returns:
/// - the store entity (Zeta1 model selected, with a license-detection
///   watcher installed for every worktree of `project`),
/// - a shared handle to the most recent request body captured on
///   `/predict_edits/v2` (`None` until a prediction request is made), and
/// - a shared handle to the completion text the fake server returns; tests
///   overwrite it to control the predicted output (defaults to a minimal
///   "hello world" excerpt).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    // Default model output in the zeta1 excerpt format: the editable region
    // contains just "hello world".
    let default_response = indoc! {"
        ```main.rs
        <|start_of_file|>
        <|editable_region_start|>
        hello world
        <|editable_region_end|>
        ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    // Fake HTTP server handling LLM token minting and both prediction
    // endpoint versions. Clones of the shared handles are moved into the
    // handler so the test and the server see the same state.
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh: always succeeds with a fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // v2 predictions: capture the request body for the test
                    // to inspect, then respond with the configured excerpt.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    // v3 predictions: body is not captured; the output is a
                    // fixed string rather than the configurable response.
                    (&Method::POST, "/predict_edits/v3") => {
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    output: "hello world".to_string(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    // Keep the fake RPC server alive for the duration of setup; dropping it
    // would disconnect the client.
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install a license watcher per worktree — NOTE(review): presumably
        // so predictions aren't gated on license detection; confirm against
        // EditPredictionStore's license checks.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2380
2381fn to_completion_edits(
2382 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2383 buffer: &Entity<Buffer>,
2384 cx: &App,
2385) -> Vec<(Range<Anchor>, Arc<str>)> {
2386 let buffer = buffer.read(cx);
2387 iterator
2388 .into_iter()
2389 .map(|(range, text)| {
2390 (
2391 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2392 text,
2393 )
2394 })
2395 .collect()
2396}
2397
2398fn from_completion_edits(
2399 editor_edits: &[(Range<Anchor>, Arc<str>)],
2400 buffer: &Entity<Buffer>,
2401 cx: &App,
2402) -> Vec<(Range<usize>, Arc<str>)> {
2403 let buffer = buffer.read(cx);
2404 editor_edits
2405 .iter()
2406 .map(|(range, text)| {
2407 (
2408 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2409 text.clone(),
2410 )
2411 })
2412 .collect()
2413}
2414
2415#[gpui::test]
2416async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2417 init_test(cx);
2418
2419 let fs = FakeFs::new(cx.executor());
2420 fs.insert_tree(
2421 "/project",
2422 serde_json::json!({
2423 "main.rs": "fn main() {\n \n}\n"
2424 }),
2425 )
2426 .await;
2427
2428 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2429
2430 let http_client = FakeHttpClient::create(|_req| async move {
2431 Ok(gpui::http_client::Response::builder()
2432 .status(401)
2433 .body("Unauthorized".into())
2434 .unwrap())
2435 });
2436
2437 let client =
2438 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2439 cx.update(|cx| {
2440 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2441 });
2442
2443 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2444
2445 let buffer = project
2446 .update(cx, |project, cx| {
2447 let path = project
2448 .find_project_path(path!("/project/main.rs"), cx)
2449 .unwrap();
2450 project.open_buffer(path, cx)
2451 })
2452 .await
2453 .unwrap();
2454
2455 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2456 ep_store.update(cx, |ep_store, cx| {
2457 ep_store.register_buffer(&buffer, &project, cx)
2458 });
2459 cx.background_executor.run_until_parked();
2460
2461 let completion_task = ep_store.update(cx, |ep_store, cx| {
2462 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2463 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2464 });
2465
2466 let result = completion_task.await;
2467 assert!(
2468 result.is_err(),
2469 "Without authentication and without custom URL, prediction should fail"
2470 );
2471}
2472
/// Verifies that `compute_diff_between_snapshots` produces a single unified
/// hunk covering two insertions made out of order (later line first), with
/// three lines of context on each side.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Edits are applied later-position-first so the second edit's point is
    // not shifted by the first.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // The two insertions are close enough that their context overlaps,
    // merging them into one hunk.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2540
// Runs once at binary load time (before any test executes), ensuring log
// output is initialized for every test in this file.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}