1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Buffer, Point};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
/// Exercises the store's "current prediction" surfacing: a prediction that
/// targets the active buffer is reported as `BufferEditPrediction::Local`,
/// while a prediction targeting a *different* file (requested here after a
/// diagnostic appears in that file) is reported as `Jump` when queried from
/// the active buffer, and as `Local` when queried from the target buffer.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and mark it active so refreshes originate from it.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Respond with a diff that edits the active file itself.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // Since the prediction applies to the queried buffer, it is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard it so the next prediction can become current.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request
    // (observed below) without any explicit refresh call.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                Hola!
                -Como
                +Como estas?
                Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from buffer1, a prediction for 2.txt is a `Jump` to that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
170#[gpui::test]
171async fn test_simple_request(cx: &mut TestAppContext) {
172 let (ep_store, mut requests) = init_test_with_fake_client(cx);
173 let fs = FakeFs::new(cx.executor());
174 fs.insert_tree(
175 "/root",
176 json!({
177 "foo.md": "Hello!\nHow\nBye\n"
178 }),
179 )
180 .await;
181 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
182
183 let buffer = project
184 .update(cx, |project, cx| {
185 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
186 project.open_buffer(path, cx)
187 })
188 .await
189 .unwrap();
190 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
191 let position = snapshot.anchor_before(language::Point::new(1, 3));
192
193 let prediction_task = ep_store.update(cx, |ep_store, cx| {
194 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
195 });
196
197 let (request, respond_tx) = requests.predict.next().await.unwrap();
198
199 // TODO Put back when we have a structured request again
200 // assert_eq!(
201 // request.excerpt_path.as_ref(),
202 // Path::new(path!("root/foo.md"))
203 // );
204 // assert_eq!(
205 // request.cursor_point,
206 // Point {
207 // line: Line(1),
208 // column: 3
209 // }
210 // );
211
212 respond_tx
213 .send(model_response(
214 &request,
215 indoc! { r"
216 --- a/root/foo.md
217 +++ b/root/foo.md
218 @@ ... @@
219 Hello!
220 -How
221 +How are you?
222 Bye
223 "},
224 ))
225 .unwrap();
226
227 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
228
229 assert_eq!(prediction.edits.len(), 1);
230 assert_eq!(
231 prediction.edits[0].0.to_point(&snapshot).start,
232 language::Point::new(1, 3)
233 );
234 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
235}
236
/// Verifies that edits made to a registered buffer are recorded as edit
/// history and included (as a unified diff) in the prompt of the next
/// prediction request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registration must happen before the edit so the store observes it.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" on the (empty) second line; this becomes the history event.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt sent to the model should contain the diff of the edit above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
309
/// Verifies that a pause longer than `LAST_CHANGE_GROUPING_TIME` between two
/// bursts of edits splits the edit history into two separate events, even
/// though the bursts touch the same line.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 0: only the pre-pause insertion of "How".
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 1: both post-pause edits, coalesced into one diff.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
387
/// Exercises line-distance-based coalescing of edit-history events: edits
/// within 8 lines above or below an existing event's range merge into it,
/// while edits farther away start a new event.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the coalesced near-range edits, then the distant one.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
566
567fn render_events(events: &[StoredEvent]) -> String {
568 events
569 .iter()
570 .map(|e| {
571 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
572 diff.as_str()
573 })
574 .collect::<Vec<_>>()
575 .join("\n---\n")
576}
577
578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
579 events
580 .iter()
581 .map(|e| {
582 let zeta_prompt::Event::BufferChange {
583 diff, predicted, ..
584 } = e.event.as_ref();
585 let prefix = if *predicted { "predicted" } else { "manual" };
586 format!("{}\n{}", prefix, diff)
587 })
588 .collect()
589}
590
/// Exercises how the `predicted` flag on edit-history events interacts with
/// coalescing: manual edits merge with manual edits, accepted predictions
/// merge with predictions, a manual edit after a prediction starts a new
/// event, and once activity moves elsewhere, earlier mixed events at the old
/// location may merge (with `predicted` true only if all parts were predicted).
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    //
    // `report_changes_for_buffer(.., true, ..)` marks the preceding edit as a
    // predicted (accepted) one.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
849
850#[gpui::test]
851async fn test_empty_prediction(cx: &mut TestAppContext) {
852 let (ep_store, mut requests) = init_test_with_fake_client(cx);
853 let fs = FakeFs::new(cx.executor());
854 fs.insert_tree(
855 "/root",
856 json!({
857 "foo.md": "Hello!\nHow\nBye\n"
858 }),
859 )
860 .await;
861 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
862
863 let buffer = project
864 .update(cx, |project, cx| {
865 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
866 project.open_buffer(path, cx)
867 })
868 .await
869 .unwrap();
870 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
871 let position = snapshot.anchor_before(language::Point::new(1, 3));
872
873 ep_store.update(cx, |ep_store, cx| {
874 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
875 });
876
877 let (request, respond_tx) = requests.predict.next().await.unwrap();
878 let response = model_response(&request, "");
879 let id = response.request_id.clone();
880 respond_tx.send(response).unwrap();
881
882 cx.run_until_parked();
883
884 ep_store.update(cx, |ep_store, cx| {
885 assert!(
886 ep_store
887 .prediction_at(&buffer, None, &project, cx)
888 .is_none()
889 );
890 });
891
892 // prediction is reported as rejected
893 let (reject_request, _) = requests.reject.next().await.unwrap();
894
895 assert_eq!(
896 &reject_request.rejections,
897 &[EditPredictionRejection {
898 request_id: id,
899 reason: EditPredictionRejectReason::Empty,
900 was_shown: false
901 }]
902 );
903}
904
/// If the user manually types the predicted text while the request is in
/// flight, interpolating the response against the new buffer content yields
/// no remaining edits; the prediction is dropped and reported as rejected
/// with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Before the response arrives, the buffer already reaches the state the
    // prediction would produce.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
964
/// Canonical model response shared by several tests below: a unified diff
/// that rewrites line 2 of `root/foo.md` from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
974
/// When a second prediction arrives that is at least as good as the current
/// one, it replaces it; the first prediction is reported as rejected with
/// reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
1056
/// When a second prediction arrives that is *worse* than the current one
/// (here: a shorter completion of the same line), the current prediction is
/// kept and the newcomer is reported as rejected with reason
/// `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
1150
/// If a newer request's response arrives first, the older in-flight request
/// is cancelled: its late response does not replace the current prediction,
/// and it is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the stale first request respond.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
1241
// Exercises cancellation bookkeeping across three overlapping prediction
// requests: with two already pending, starting a third cancels the second;
// the cancelled request's late response must not replace the current
// prediction, and the server is told the second was Canceled and the first
// Replaced once the third response supersedes it.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        // (the `[1]` below is the second request's index in the pending queue)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Resolve the first (still valid) request; it should become current.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Resolve the cancelled (second) request; its response must be discarded.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Resolve the third request; it should supersede the first.
    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Rejections are batched: Canceled (second request) plus Replaced (first,
    // which the third response displaced).
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1377
// Covers the rejection-reporting pipeline: rejections are debounced and sent
// in batches; reaching the batch-size limit flushes immediately (in chunks of
// 50); and items from a failed request are retained and retried together with
// subsequently queued rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    // Queue two rejections with different reasons/visibility flags.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            cx,
        );
    });

    // Advance past the debounce window so the batch is sent.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // The 20 leftover items (batch-50..batch-69) arrive in the next flush.
    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    // (dropping the responder makes the fake server's response future error)
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1494
1495// Skipped until we start including diagnostics in prompt
1496// #[gpui::test]
1497// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1498// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1499// let fs = FakeFs::new(cx.executor());
1500// fs.insert_tree(
1501// "/root",
1502// json!({
1503// "foo.md": "Hello!\nBye"
1504// }),
1505// )
1506// .await;
1507// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1508
1509// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1510// let diagnostic = lsp::Diagnostic {
1511// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1512// severity: Some(lsp::DiagnosticSeverity::ERROR),
1513// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1514// ..Default::default()
1515// };
1516
1517// project.update(cx, |project, cx| {
1518// project.lsp_store().update(cx, |lsp_store, cx| {
1519// // Create some diagnostics
1520// lsp_store
1521// .update_diagnostics(
1522// LanguageServerId(0),
1523// lsp::PublishDiagnosticsParams {
1524// uri: path_to_buffer_uri.clone(),
1525// diagnostics: vec![diagnostic],
1526// version: None,
1527// },
1528// None,
1529// language::DiagnosticSourceKind::Pushed,
1530// &[],
1531// cx,
1532// )
1533// .unwrap();
1534// });
1535// });
1536
1537// let buffer = project
1538// .update(cx, |project, cx| {
1539// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1540// project.open_buffer(path, cx)
1541// })
1542// .await
1543// .unwrap();
1544
1545// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1546// let position = snapshot.anchor_before(language::Point::new(0, 0));
1547
1548// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1549// ep_store.request_prediction(&project, &buffer, position, cx)
1550// });
1551
1552// let (request, _respond_tx) = req_rx.next().await.unwrap();
1553
1554// assert_eq!(request.diagnostic_groups.len(), 1);
1555// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1556// .unwrap();
1557// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1558// assert_eq!(
1559// value,
1560// json!({
1561// "entries": [{
1562// "range": {
1563// "start": 8,
1564// "end": 10
1565// },
1566// "diagnostic": {
1567// "source": null,
1568// "code": null,
1569// "code_description": null,
1570// "severity": 1,
1571// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1572// "markdown": null,
1573// "group_id": 0,
1574// "is_primary": true,
1575// "is_disk_based": false,
1576// "is_unnecessary": false,
1577// "source_kind": "Pushed",
1578// "data": null,
1579// "underline": true
1580// }
1581// }],
1582// "primary_ix": 0
1583// })
1584// );
1585// }
1586
1587// Generate a model response that would apply the given diff to the active file.
1588fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1589 let excerpt =
1590 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1591 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1592
1593 PredictEditsV3Response {
1594 request_id: Uuid::new_v4().to_string(),
1595 output: new_excerpt,
1596 }
1597}
1598
1599fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1600 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1601}
1602
/// Receiving ends of the fake-server channels created by
/// `init_test_with_fake_client`. Each item pairs a decoded request body with
/// a oneshot sender the test uses to deliver (or fail) the response.
struct RequestChannels {
    // Requests to `/predict_edits/v3`, with a sender for the model response.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests to `/predict_edits/reject`; sending `()` acknowledges success,
    // dropping the sender simulates a failed request.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1610
/// Sets up an `EditPredictionStore` backed by a fake HTTP client.
///
/// The fake client serves three endpoints: `/client/llm_tokens` (returns a
/// static token), `/predict_edits/v3`, and `/predict_edits/reject`. The two
/// prediction endpoints forward each decoded request body, paired with a
/// oneshot responder, over the returned `RequestChannels` so tests can
/// inspect requests and control when (or whether) responses arrive.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            // Prediction bodies are zstd-compressed JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            // Hand the request to the test; block until it responds.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            // Reject bodies are plain (uncompressed) JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // If the test drops `res_tx`, `res_rx.await?`
                            // errors here, simulating a failed request.
                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1680
// Verifies `EditPrediction::interpolate`: as the user edits the buffer after
// a prediction was computed, the prediction's remaining edits are re-anchored
// against the new snapshot. Edits the user has already typed drop out, edits
// the user contradicts invalidate the whole prediction (`None`).
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: replace "rem" (2..5) with "REM", delete 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // A prediction with placeholder prompt inputs; only `edits`/`snapshot`
    // matter for interpolation.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unedited buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the replacement into an insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the replacement shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the full replacement removes that edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of the typed text re-introduces the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion removes that edit too.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that contradicts the prediction invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1799
// Checks that applying a predicted completion yields exactly the model's
// output text for two cases: an identifier rename propagated to later
// usages, and a string literal extended in place.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Rename `word` -> `word_1` in the range expression.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extend the string literal and add the missing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1851
1852#[gpui::test]
1853async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1854 init_test(cx);
1855
1856 let buffer_content = "lorem\n";
1857 let completion_response = "lorem\nipsum\n";
1858
1859 assert_eq!(
1860 apply_edit_prediction(buffer_content, completion_response, cx).await,
1861 "lorem\nipsum\n"
1862 );
1863}
1864
// Test that zeta2's newline normalization logic doesn't insert spurious newlines.
// When the buffer ends without a trailing newline, but the model returns output
// with a trailing newline, zeta2 should normalize both sides before diffing
// so no spurious newline is inserted.
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of the only line ("hello" is 5 columns long).
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve each anchor range to offsets for a stable comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
1929
1930fn init_test(cx: &mut TestAppContext) {
1931 cx.update(|cx| {
1932 let settings_store = SettingsStore::test(cx);
1933 cx.set_global(settings_store);
1934 });
1935}
1936
1937async fn apply_edit_prediction(
1938 buffer_content: &str,
1939 completion_response: &str,
1940 cx: &mut TestAppContext,
1941) -> String {
1942 let fs = project::FakeFs::new(cx.executor());
1943 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1944 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1945 let (ep_store, response) = make_test_ep_store(&project, cx).await;
1946 *response.lock() = completion_response.to_string();
1947 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1948 buffer.update(cx, |buffer, cx| {
1949 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1950 });
1951 buffer.read_with(cx, |buffer, _| buffer.text())
1952}
1953
1954async fn run_edit_prediction(
1955 buffer: &Entity<Buffer>,
1956 project: &Entity<Project>,
1957 ep_store: &Entity<EditPredictionStore>,
1958 cx: &mut TestAppContext,
1959) -> EditPrediction {
1960 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1961 ep_store.update(cx, |ep_store, cx| {
1962 ep_store.register_buffer(buffer, &project, cx)
1963 });
1964 cx.background_executor.run_until_parked();
1965 let prediction_task = ep_store.update(cx, |ep_store, cx| {
1966 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1967 });
1968 prediction_task.await.unwrap().unwrap().prediction.unwrap()
1969}
1970
/// Builds an `EditPredictionStore` (Zeta1 model) backed by a fake HTTP client.
///
/// Returns the store plus a shared string: whatever the string holds when
/// `/predict_edits/v3` is hit becomes that response's `output`. Each
/// prediction response gets a sequential `request-N` id.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token endpoint: always succeeds with a static token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction endpoint: echo the configured completion.
                    (&Method::POST, "/predict_edits/v3") => {
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    output: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Pre-populate license watchers so predictions aren't blocked on
        // license detection for the test worktrees.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2042
2043fn to_completion_edits(
2044 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2045 buffer: &Entity<Buffer>,
2046 cx: &App,
2047) -> Vec<(Range<Anchor>, Arc<str>)> {
2048 let buffer = buffer.read(cx);
2049 iterator
2050 .into_iter()
2051 .map(|(range, text)| {
2052 (
2053 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2054 text,
2055 )
2056 })
2057 .collect()
2058}
2059
2060fn from_completion_edits(
2061 editor_edits: &[(Range<Anchor>, Arc<str>)],
2062 buffer: &Entity<Buffer>,
2063 cx: &App,
2064) -> Vec<(Range<usize>, Arc<str>)> {
2065 let buffer = buffer.read(cx);
2066 editor_edits
2067 .iter()
2068 .map(|(range, text)| {
2069 (
2070 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2071 text.clone(),
2072 )
2073 })
2074 .collect()
2075}
2076
// With no authentication (every HTTP request returns 401) and no custom
// prediction URL configured, requesting a prediction must fail rather than
// hang or silently succeed.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request — including the LLM token fetch — is rejected with 401.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2134
// Two insertions made out of order (later line first) should produce a single
// unified-diff hunk, with both insertions at their final positions and the
// surrounding context lines intact.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insert at row 12 first, then row 8 — deliberately out of document
    // order, to check the diff still comes out position-sorted.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2202
// Runs once at test-binary load time (before any test) to set up logging.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}