1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Buffer, Point};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
/// Verifies which prediction `prediction_at` surfaces for a buffer.
///
/// First a prediction for the edited buffer (`1.txt`) is reported as `Local`. After it is
/// rejected, a diagnostic published for `2.txt` triggers a new prediction that edits `2.txt`.
/// Seen from `1.txt` that prediction is a `Jump`; seen from `2.txt` it is `Local`.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open `1.txt` and mark it as the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" (row 1, column 3).
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction edits the buffer it was requested from, so it is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Reject the current prediction (as discarded) before the next scenario.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file: publishing an error diagnostic for
    // `2.txt` is expected to trigger a new prediction request, which is awaited below.

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from `buffer1`, the current prediction edits `2.txt`, so it is a `Jump` to that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    // Once `2.txt` is open, the same prediction is `Local` to that buffer.
    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
170#[gpui::test]
171async fn test_simple_request(cx: &mut TestAppContext) {
172 let (ep_store, mut requests) = init_test_with_fake_client(cx);
173 let fs = FakeFs::new(cx.executor());
174 fs.insert_tree(
175 "/root",
176 json!({
177 "foo.md": "Hello!\nHow\nBye\n"
178 }),
179 )
180 .await;
181 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
182
183 let buffer = project
184 .update(cx, |project, cx| {
185 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
186 project.open_buffer(path, cx)
187 })
188 .await
189 .unwrap();
190 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
191 let position = snapshot.anchor_before(language::Point::new(1, 3));
192
193 let prediction_task = ep_store.update(cx, |ep_store, cx| {
194 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
195 });
196
197 let (request, respond_tx) = requests.predict.next().await.unwrap();
198
199 // TODO Put back when we have a structured request again
200 // assert_eq!(
201 // request.excerpt_path.as_ref(),
202 // Path::new(path!("root/foo.md"))
203 // );
204 // assert_eq!(
205 // request.cursor_point,
206 // Point {
207 // line: Line(1),
208 // column: 3
209 // }
210 // );
211
212 respond_tx
213 .send(model_response(
214 &request,
215 indoc! { r"
216 --- a/root/foo.md
217 +++ b/root/foo.md
218 @@ ... @@
219 Hello!
220 -How
221 +How are you?
222 Bye
223 "},
224 ))
225 .unwrap();
226
227 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
228
229 assert_eq!(prediction.edits.len(), 1);
230 assert_eq!(
231 prediction.edits[0].0.to_point(&snapshot).start,
232 language::Point::new(1, 3)
233 );
234 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
235}
236
237#[gpui::test]
238async fn test_request_events(cx: &mut TestAppContext) {
239 let (ep_store, mut requests) = init_test_with_fake_client(cx);
240 let fs = FakeFs::new(cx.executor());
241 fs.insert_tree(
242 "/root",
243 json!({
244 "foo.md": "Hello!\n\nBye\n"
245 }),
246 )
247 .await;
248 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
249
250 let buffer = project
251 .update(cx, |project, cx| {
252 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
253 project.open_buffer(path, cx)
254 })
255 .await
256 .unwrap();
257
258 ep_store.update(cx, |ep_store, cx| {
259 ep_store.register_buffer(&buffer, &project, cx);
260 });
261
262 buffer.update(cx, |buffer, cx| {
263 buffer.edit(vec![(7..7, "How")], None, cx);
264 });
265
266 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
267 let position = snapshot.anchor_before(language::Point::new(1, 3));
268
269 let prediction_task = ep_store.update(cx, |ep_store, cx| {
270 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
271 });
272
273 let (request, respond_tx) = requests.predict.next().await.unwrap();
274
275 let prompt = prompt_from_request(&request);
276 assert!(
277 prompt.contains(indoc! {"
278 --- a/root/foo.md
279 +++ b/root/foo.md
280 @@ -1,3 +1,3 @@
281 Hello!
282 -
283 +How
284 Bye
285 "}),
286 "{prompt}"
287 );
288
289 respond_tx
290 .send(model_response(
291 &request,
292 indoc! {r#"
293 --- a/root/foo.md
294 +++ b/root/foo.md
295 @@ ... @@
296 Hello!
297 -How
298 +How are you?
299 Bye
300 "#},
301 ))
302 .unwrap();
303
304 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
305
306 assert_eq!(prediction.edits.len(), 1);
307 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
308}
309
310#[gpui::test]
311async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
312 let (ep_store, _requests) = init_test_with_fake_client(cx);
313 let fs = FakeFs::new(cx.executor());
314 fs.insert_tree(
315 "/root",
316 json!({
317 "foo.md": "Hello!\n\nBye\n"
318 }),
319 )
320 .await;
321 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
322
323 let buffer = project
324 .update(cx, |project, cx| {
325 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
326 project.open_buffer(path, cx)
327 })
328 .await
329 .unwrap();
330
331 ep_store.update(cx, |ep_store, cx| {
332 ep_store.register_buffer(&buffer, &project, cx);
333 });
334
335 // First burst: insert "How"
336 buffer.update(cx, |buffer, cx| {
337 buffer.edit(vec![(7..7, "How")], None, cx);
338 });
339
340 // Simulate a pause longer than the grouping threshold (e.g. 500ms).
341 cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
342 cx.run_until_parked();
343
344 // Second burst: append " are you?" immediately after "How" on the same line.
345 //
346 // Keeping both bursts on the same line ensures the existing line-span coalescing logic
347 // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
348 buffer.update(cx, |buffer, cx| {
349 buffer.edit(vec![(10..10, " are you?")], None, cx);
350 });
351
352 // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
353 // advanced after the pause boundary is recorded, making pause-splitting deterministic.
354 buffer.update(cx, |buffer, cx| {
355 buffer.edit(vec![(19..19, "!")], None, cx);
356 });
357
358 // With time-based splitting, there are two distinct events.
359 let events = ep_store.update(cx, |ep_store, cx| {
360 ep_store.edit_history_for_project(&project, cx)
361 });
362 assert_eq!(events.len(), 2);
363 let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
364 assert_eq!(
365 diff.as_str(),
366 indoc! {"
367 @@ -1,3 +1,3 @@
368 Hello!
369 -
370 +How
371 Bye
372 "}
373 );
374
375 let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
376 assert_eq!(
377 diff.as_str(),
378 indoc! {"
379 @@ -1,3 +1,3 @@
380 Hello!
381 -How
382 +How are you?!
383 Bye
384 "}
385 );
386}
387
388#[gpui::test]
389async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
390 let (ep_store, _requests) = init_test_with_fake_client(cx);
391 let fs = FakeFs::new(cx.executor());
392
393 // Create a file with 30 lines to test line-based coalescing
394 let content = (1..=30)
395 .map(|i| format!("Line {}\n", i))
396 .collect::<String>();
397 fs.insert_tree(
398 "/root",
399 json!({
400 "foo.md": content
401 }),
402 )
403 .await;
404 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
405
406 let buffer = project
407 .update(cx, |project, cx| {
408 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
409 project.open_buffer(path, cx)
410 })
411 .await
412 .unwrap();
413
414 ep_store.update(cx, |ep_store, cx| {
415 ep_store.register_buffer(&buffer, &project, cx);
416 });
417
418 // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
419 buffer.update(cx, |buffer, cx| {
420 let start = Point::new(10, 0).to_offset(buffer);
421 let end = Point::new(13, 0).to_offset(buffer);
422 buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
423 });
424
425 let events = ep_store.update(cx, |ep_store, cx| {
426 ep_store.edit_history_for_project(&project, cx)
427 });
428 assert_eq!(
429 render_events(&events),
430 indoc! {"
431 @@ -8,9 +8,8 @@
432 Line 8
433 Line 9
434 Line 10
435 -Line 11
436 -Line 12
437 -Line 13
438 +Middle A
439 +Middle B
440 Line 14
441 Line 15
442 Line 16
443 "},
444 "After first edit"
445 );
446
447 // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
448 // This tests that coalescing considers the START of the existing range
449 buffer.update(cx, |buffer, cx| {
450 let offset = Point::new(5, 0).to_offset(buffer);
451 buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
452 });
453
454 let events = ep_store.update(cx, |ep_store, cx| {
455 ep_store.edit_history_for_project(&project, cx)
456 });
457 assert_eq!(
458 render_events(&events),
459 indoc! {"
460 @@ -3,14 +3,14 @@
461 Line 3
462 Line 4
463 Line 5
464 +Above
465 Line 6
466 Line 7
467 Line 8
468 Line 9
469 Line 10
470 -Line 11
471 -Line 12
472 -Line 13
473 +Middle A
474 +Middle B
475 Line 14
476 Line 15
477 Line 16
478 "},
479 "After inserting above (should coalesce)"
480 );
481
482 // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
483 // This tests that coalescing considers the END of the existing range
484 buffer.update(cx, |buffer, cx| {
485 let offset = Point::new(14, 0).to_offset(buffer);
486 buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
487 });
488
489 let events = ep_store.update(cx, |ep_store, cx| {
490 ep_store.edit_history_for_project(&project, cx)
491 });
492 assert_eq!(
493 render_events(&events),
494 indoc! {"
495 @@ -3,15 +3,16 @@
496 Line 3
497 Line 4
498 Line 5
499 +Above
500 Line 6
501 Line 7
502 Line 8
503 Line 9
504 Line 10
505 -Line 11
506 -Line 12
507 -Line 13
508 +Middle A
509 +Middle B
510 Line 14
511 +Below
512 Line 15
513 Line 16
514 Line 17
515 "},
516 "After inserting below (should coalesce)"
517 );
518
519 // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
520 // This should NOT coalesce - creates a new event
521 buffer.update(cx, |buffer, cx| {
522 let offset = Point::new(25, 0).to_offset(buffer);
523 buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
524 });
525
526 let events = ep_store.update(cx, |ep_store, cx| {
527 ep_store.edit_history_for_project(&project, cx)
528 });
529 assert_eq!(
530 render_events(&events),
531 indoc! {"
532 @@ -3,15 +3,16 @@
533 Line 3
534 Line 4
535 Line 5
536 +Above
537 Line 6
538 Line 7
539 Line 8
540 Line 9
541 Line 10
542 -Line 11
543 -Line 12
544 -Line 13
545 +Middle A
546 +Middle B
547 Line 14
548 +Below
549 Line 15
550 Line 16
551 Line 17
552
553 ---
554 @@ -23,6 +23,7 @@
555 Line 22
556 Line 23
557 Line 24
558 +Far below
559 Line 25
560 Line 26
561 Line 27
562 "},
563 "After inserting far below (should NOT coalesce)"
564 );
565}
566
567fn render_events(events: &[StoredEvent]) -> String {
568 events
569 .iter()
570 .map(|e| {
571 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
572 diff.as_str()
573 })
574 .collect::<Vec<_>>()
575 .join("\n---\n")
576}
577
578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
579 events
580 .iter()
581 .map(|e| {
582 let zeta_prompt::Event::BufferChange {
583 diff, predicted, ..
584 } = e.event.as_ref();
585 let prefix = if *predicted { "predicted" } else { "manual" };
586 format!("{}\n{}", prefix, diff)
587 })
588 .collect()
589}
590
591#[gpui::test]
592async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
593 let (ep_store, _requests) = init_test_with_fake_client(cx);
594 let fs = FakeFs::new(cx.executor());
595 fs.insert_tree(
596 "/root",
597 json!({
598 "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
599 }),
600 )
601 .await;
602 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
603
604 let buffer = project
605 .update(cx, |project, cx| {
606 let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
607 project.open_buffer(path, cx)
608 })
609 .await
610 .unwrap();
611
612 ep_store.update(cx, |ep_store, cx| {
613 ep_store.register_buffer(&buffer, &project, cx);
614 });
615
616 // Case 1: Manual edits have `predicted` set to false.
617 buffer.update(cx, |buffer, cx| {
618 buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
619 });
620
621 let events = ep_store.update(cx, |ep_store, cx| {
622 ep_store.edit_history_for_project(&project, cx)
623 });
624
625 assert_eq!(
626 render_events_with_predicted(&events),
627 vec![indoc! {"
628 manual
629 @@ -1,4 +1,4 @@
630 -line 0
631 +LINE ZERO
632 line 1
633 line 2
634 line 3
635 "}]
636 );
637
638 // Case 2: Multiple successive manual edits near each other are merged into one
639 // event with `predicted` set to false.
640 buffer.update(cx, |buffer, cx| {
641 let offset = Point::new(1, 0).to_offset(buffer);
642 let end = Point::new(1, 6).to_offset(buffer);
643 buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
644 });
645
646 let events = ep_store.update(cx, |ep_store, cx| {
647 ep_store.edit_history_for_project(&project, cx)
648 });
649 assert_eq!(
650 render_events_with_predicted(&events),
651 vec![indoc! {"
652 manual
653 @@ -1,5 +1,5 @@
654 -line 0
655 -line 1
656 +LINE ZERO
657 +LINE ONE
658 line 2
659 line 3
660 line 4
661 "}]
662 );
663
664 // Case 3: Accepted predictions have `predicted` set to true.
665 // Case 5: A manual edit that follows a predicted edit is not merged with the
666 // predicted edit, even if it is nearby.
667 ep_store.update(cx, |ep_store, cx| {
668 buffer.update(cx, |buffer, cx| {
669 let offset = Point::new(2, 0).to_offset(buffer);
670 let end = Point::new(2, 6).to_offset(buffer);
671 buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
672 });
673 ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
674 });
675
676 let events = ep_store.update(cx, |ep_store, cx| {
677 ep_store.edit_history_for_project(&project, cx)
678 });
679 assert_eq!(
680 render_events_with_predicted(&events),
681 vec![
682 indoc! {"
683 manual
684 @@ -1,5 +1,5 @@
685 -line 0
686 -line 1
687 +LINE ZERO
688 +LINE ONE
689 line 2
690 line 3
691 line 4
692 "},
693 indoc! {"
694 predicted
695 @@ -1,6 +1,6 @@
696 LINE ZERO
697 LINE ONE
698 -line 2
699 +LINE TWO
700 line 3
701 line 4
702 line 5
703 "}
704 ]
705 );
706
707 // Case 4: Multiple successive accepted predictions near each other are merged
708 // into one event with `predicted` set to true.
709 ep_store.update(cx, |ep_store, cx| {
710 buffer.update(cx, |buffer, cx| {
711 let offset = Point::new(3, 0).to_offset(buffer);
712 let end = Point::new(3, 6).to_offset(buffer);
713 buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
714 });
715 ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
716 });
717
718 let events = ep_store.update(cx, |ep_store, cx| {
719 ep_store.edit_history_for_project(&project, cx)
720 });
721 assert_eq!(
722 render_events_with_predicted(&events),
723 vec![
724 indoc! {"
725 manual
726 @@ -1,5 +1,5 @@
727 -line 0
728 -line 1
729 +LINE ZERO
730 +LINE ONE
731 line 2
732 line 3
733 line 4
734 "},
735 indoc! {"
736 predicted
737 @@ -1,7 +1,7 @@
738 LINE ZERO
739 LINE ONE
740 -line 2
741 -line 3
742 +LINE TWO
743 +LINE THREE
744 line 4
745 line 5
746 line 6
747 "}
748 ]
749 );
750
751 // Case 5 (continued): A manual edit that follows a predicted edit is not merged
752 // with the predicted edit, even if it is nearby.
753 buffer.update(cx, |buffer, cx| {
754 let offset = Point::new(4, 0).to_offset(buffer);
755 let end = Point::new(4, 6).to_offset(buffer);
756 buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
757 });
758
759 let events = ep_store.update(cx, |ep_store, cx| {
760 ep_store.edit_history_for_project(&project, cx)
761 });
762 assert_eq!(
763 render_events_with_predicted(&events),
764 vec![
765 indoc! {"
766 manual
767 @@ -1,5 +1,5 @@
768 -line 0
769 -line 1
770 +LINE ZERO
771 +LINE ONE
772 line 2
773 line 3
774 line 4
775 "},
776 indoc! {"
777 predicted
778 @@ -1,7 +1,7 @@
779 LINE ZERO
780 LINE ONE
781 -line 2
782 -line 3
783 +LINE TWO
784 +LINE THREE
785 line 4
786 line 5
787 line 6
788 "},
789 indoc! {"
790 manual
791 @@ -2,7 +2,7 @@
792 LINE ONE
793 LINE TWO
794 LINE THREE
795 -line 4
796 +LINE FOUR
797 line 5
798 line 6
799 line 7
800 "}
801 ]
802 );
803
804 // Case 6: If we then perform a manual edit at a *different* location (more than
805 // 8 lines away), then the edits at the prior location can be merged with each
806 // other, even if some are predicted and some are not. `predicted` means all
807 // constituent edits were predicted.
808 buffer.update(cx, |buffer, cx| {
809 let offset = Point::new(14, 0).to_offset(buffer);
810 let end = Point::new(14, 7).to_offset(buffer);
811 buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
812 });
813
814 let events = ep_store.update(cx, |ep_store, cx| {
815 ep_store.edit_history_for_project(&project, cx)
816 });
817 assert_eq!(
818 render_events_with_predicted(&events),
819 vec![
820 indoc! {"
821 manual
822 @@ -1,8 +1,8 @@
823 -line 0
824 -line 1
825 -line 2
826 -line 3
827 -line 4
828 +LINE ZERO
829 +LINE ONE
830 +LINE TWO
831 +LINE THREE
832 +LINE FOUR
833 line 5
834 line 6
835 line 7
836 "},
837 indoc! {"
838 manual
839 @@ -12,4 +12,4 @@
840 line 11
841 line 12
842 line 13
843 -line 14
844 +LINE FOURTEEN
845 "}
846 ]
847 );
848}
849
850#[gpui::test]
851async fn test_empty_prediction(cx: &mut TestAppContext) {
852 let (ep_store, mut requests) = init_test_with_fake_client(cx);
853 let fs = FakeFs::new(cx.executor());
854 fs.insert_tree(
855 "/root",
856 json!({
857 "foo.md": "Hello!\nHow\nBye\n"
858 }),
859 )
860 .await;
861 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
862
863 let buffer = project
864 .update(cx, |project, cx| {
865 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
866 project.open_buffer(path, cx)
867 })
868 .await
869 .unwrap();
870 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
871 let position = snapshot.anchor_before(language::Point::new(1, 3));
872
873 ep_store.update(cx, |ep_store, cx| {
874 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
875 });
876
877 let (request, respond_tx) = requests.predict.next().await.unwrap();
878 let response = model_response(&request, "");
879 let id = response.request_id.clone();
880 respond_tx.send(response).unwrap();
881
882 cx.run_until_parked();
883
884 ep_store.update(cx, |ep_store, cx| {
885 assert!(
886 ep_store
887 .prediction_at(&buffer, None, &project, cx)
888 .is_none()
889 );
890 });
891
892 // prediction is reported as rejected
893 let (reject_request, _) = requests.reject.next().await.unwrap();
894
895 assert_eq!(
896 &reject_request.rejections,
897 &[EditPredictionRejection {
898 request_id: id,
899 reason: EditPredictionRejectReason::Empty,
900 was_shown: false
901 }]
902 );
903}
904
905#[gpui::test]
906async fn test_interpolated_empty(cx: &mut TestAppContext) {
907 let (ep_store, mut requests) = init_test_with_fake_client(cx);
908 let fs = FakeFs::new(cx.executor());
909 fs.insert_tree(
910 "/root",
911 json!({
912 "foo.md": "Hello!\nHow\nBye\n"
913 }),
914 )
915 .await;
916 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
917
918 let buffer = project
919 .update(cx, |project, cx| {
920 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
921 project.open_buffer(path, cx)
922 })
923 .await
924 .unwrap();
925 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
926 let position = snapshot.anchor_before(language::Point::new(1, 3));
927
928 ep_store.update(cx, |ep_store, cx| {
929 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
930 });
931
932 let (request, respond_tx) = requests.predict.next().await.unwrap();
933
934 buffer.update(cx, |buffer, cx| {
935 buffer.set_text("Hello!\nHow are you?\nBye", cx);
936 });
937
938 let response = model_response(&request, SIMPLE_DIFF);
939 let id = response.request_id.clone();
940 respond_tx.send(response).unwrap();
941
942 cx.run_until_parked();
943
944 ep_store.update(cx, |ep_store, cx| {
945 assert!(
946 ep_store
947 .prediction_at(&buffer, None, &project, cx)
948 .is_none()
949 );
950 });
951
952 // prediction is reported as rejected
953 let (reject_request, _) = requests.reject.next().await.unwrap();
954
955 assert_eq!(
956 &reject_request.rejections,
957 &[EditPredictionRejection {
958 request_id: id,
959 reason: EditPredictionRejectReason::InterpolatedEmpty,
960 was_shown: false
961 }]
962 );
963}
964
/// Canned model response for `root/foo.md`. It replaces the line "How" with "How are you?",
/// with "Hello!" and "Bye" as unchanged context. Context lines keep one leading space, as
/// unified-diff syntax requires; `indoc!` strips only the common indentation.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
974
975#[gpui::test]
976async fn test_replace_current(cx: &mut TestAppContext) {
977 let (ep_store, mut requests) = init_test_with_fake_client(cx);
978 let fs = FakeFs::new(cx.executor());
979 fs.insert_tree(
980 "/root",
981 json!({
982 "foo.md": "Hello!\nHow\nBye\n"
983 }),
984 )
985 .await;
986 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
987
988 let buffer = project
989 .update(cx, |project, cx| {
990 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
991 project.open_buffer(path, cx)
992 })
993 .await
994 .unwrap();
995 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
996 let position = snapshot.anchor_before(language::Point::new(1, 3));
997
998 ep_store.update(cx, |ep_store, cx| {
999 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1000 });
1001
1002 let (request, respond_tx) = requests.predict.next().await.unwrap();
1003 let first_response = model_response(&request, SIMPLE_DIFF);
1004 let first_id = first_response.request_id.clone();
1005 respond_tx.send(first_response).unwrap();
1006
1007 cx.run_until_parked();
1008
1009 ep_store.update(cx, |ep_store, cx| {
1010 assert_eq!(
1011 ep_store
1012 .prediction_at(&buffer, None, &project, cx)
1013 .unwrap()
1014 .id
1015 .0,
1016 first_id
1017 );
1018 });
1019
1020 // a second request is triggered
1021 ep_store.update(cx, |ep_store, cx| {
1022 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1023 });
1024
1025 let (request, respond_tx) = requests.predict.next().await.unwrap();
1026 let second_response = model_response(&request, SIMPLE_DIFF);
1027 let second_id = second_response.request_id.clone();
1028 respond_tx.send(second_response).unwrap();
1029
1030 cx.run_until_parked();
1031
1032 ep_store.update(cx, |ep_store, cx| {
1033 // second replaces first
1034 assert_eq!(
1035 ep_store
1036 .prediction_at(&buffer, None, &project, cx)
1037 .unwrap()
1038 .id
1039 .0,
1040 second_id
1041 );
1042 });
1043
1044 // first is reported as replaced
1045 let (reject_request, _) = requests.reject.next().await.unwrap();
1046
1047 assert_eq!(
1048 &reject_request.rejections,
1049 &[EditPredictionRejection {
1050 request_id: first_id,
1051 reason: EditPredictionRejectReason::Replaced,
1052 was_shown: false
1053 }]
1054 );
1055}
1056
1057#[gpui::test]
1058async fn test_current_preferred(cx: &mut TestAppContext) {
1059 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1060 let fs = FakeFs::new(cx.executor());
1061 fs.insert_tree(
1062 "/root",
1063 json!({
1064 "foo.md": "Hello!\nHow\nBye\n"
1065 }),
1066 )
1067 .await;
1068 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1069
1070 let buffer = project
1071 .update(cx, |project, cx| {
1072 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1073 project.open_buffer(path, cx)
1074 })
1075 .await
1076 .unwrap();
1077 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1078 let position = snapshot.anchor_before(language::Point::new(1, 3));
1079
1080 ep_store.update(cx, |ep_store, cx| {
1081 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1082 });
1083
1084 let (request, respond_tx) = requests.predict.next().await.unwrap();
1085 let first_response = model_response(&request, SIMPLE_DIFF);
1086 let first_id = first_response.request_id.clone();
1087 respond_tx.send(first_response).unwrap();
1088
1089 cx.run_until_parked();
1090
1091 ep_store.update(cx, |ep_store, cx| {
1092 assert_eq!(
1093 ep_store
1094 .prediction_at(&buffer, None, &project, cx)
1095 .unwrap()
1096 .id
1097 .0,
1098 first_id
1099 );
1100 });
1101
1102 // a second request is triggered
1103 ep_store.update(cx, |ep_store, cx| {
1104 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1105 });
1106
1107 let (request, respond_tx) = requests.predict.next().await.unwrap();
1108 // worse than current prediction
1109 let second_response = model_response(
1110 &request,
1111 indoc! { r"
1112 --- a/root/foo.md
1113 +++ b/root/foo.md
1114 @@ ... @@
1115 Hello!
1116 -How
1117 +How are
1118 Bye
1119 "},
1120 );
1121 let second_id = second_response.request_id.clone();
1122 respond_tx.send(second_response).unwrap();
1123
1124 cx.run_until_parked();
1125
1126 ep_store.update(cx, |ep_store, cx| {
1127 // first is preferred over second
1128 assert_eq!(
1129 ep_store
1130 .prediction_at(&buffer, None, &project, cx)
1131 .unwrap()
1132 .id
1133 .0,
1134 first_id
1135 );
1136 });
1137
1138 // second is reported as rejected
1139 let (reject_request, _) = requests.reject.next().await.unwrap();
1140
1141 assert_eq!(
1142 &reject_request.rejections,
1143 &[EditPredictionRejection {
1144 request_id: second_id,
1145 reason: EditPredictionRejectReason::CurrentPreferred,
1146 was_shown: false
1147 }]
1148 );
1149}
1150
1151#[gpui::test]
1152async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1153 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1154 let fs = FakeFs::new(cx.executor());
1155 fs.insert_tree(
1156 "/root",
1157 json!({
1158 "foo.md": "Hello!\nHow\nBye\n"
1159 }),
1160 )
1161 .await;
1162 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1163
1164 let buffer = project
1165 .update(cx, |project, cx| {
1166 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1167 project.open_buffer(path, cx)
1168 })
1169 .await
1170 .unwrap();
1171 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1172 let position = snapshot.anchor_before(language::Point::new(1, 3));
1173
1174 // start two refresh tasks
1175 ep_store.update(cx, |ep_store, cx| {
1176 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1177 });
1178
1179 let (request1, respond_first) = requests.predict.next().await.unwrap();
1180
1181 ep_store.update(cx, |ep_store, cx| {
1182 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1183 });
1184
1185 let (request, respond_second) = requests.predict.next().await.unwrap();
1186
1187 // wait for throttle
1188 cx.run_until_parked();
1189
1190 // second responds first
1191 let second_response = model_response(&request, SIMPLE_DIFF);
1192 let second_id = second_response.request_id.clone();
1193 respond_second.send(second_response).unwrap();
1194
1195 cx.run_until_parked();
1196
1197 ep_store.update(cx, |ep_store, cx| {
1198 // current prediction is second
1199 assert_eq!(
1200 ep_store
1201 .prediction_at(&buffer, None, &project, cx)
1202 .unwrap()
1203 .id
1204 .0,
1205 second_id
1206 );
1207 });
1208
1209 let first_response = model_response(&request1, SIMPLE_DIFF);
1210 let first_id = first_response.request_id.clone();
1211 respond_first.send(first_response).unwrap();
1212
1213 cx.run_until_parked();
1214
1215 ep_store.update(cx, |ep_store, cx| {
1216 // current prediction is still second, since first was cancelled
1217 assert_eq!(
1218 ep_store
1219 .prediction_at(&buffer, None, &project, cx)
1220 .unwrap()
1221 .id
1222 .0,
1223 second_id
1224 );
1225 });
1226
1227 // first is reported as rejected
1228 let (reject_request, _) = requests.reject.next().await.unwrap();
1229
1230 cx.run_until_parked();
1231
1232 assert_eq!(
1233 &reject_request.rejections,
1234 &[EditPredictionRejection {
1235 request_id: first_id,
1236 reason: EditPredictionRejectReason::Canceled,
1237 was_shown: false
1238 }]
1239 );
1240}
1241
/// Verifies cancellation when a third refresh starts while two requests are in flight:
/// the second (still pending) request is cancelled, the first request's prediction is
/// shown until the third one arrives and replaces it, and both the cancelled and the
/// replaced predictions are then reported in a single rejection request.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Start the first of two refresh tasks.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    // Start the second refresh task while the first request is still unanswered.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // Two requests are pending, so the newer of them (the second request, recorded
        // as `1`) is marked as cancelled.
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both rejections arrive in one request: the cancelled second prediction, followed
    // by the first prediction that the third one replaced.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1377
1378#[gpui::test]
1379async fn test_rejections_flushing(cx: &mut TestAppContext) {
1380 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1381
1382 ep_store.update(cx, |ep_store, cx| {
1383 ep_store.reject_prediction(
1384 EditPredictionId("test-1".into()),
1385 EditPredictionRejectReason::Discarded,
1386 false,
1387 cx,
1388 );
1389 ep_store.reject_prediction(
1390 EditPredictionId("test-2".into()),
1391 EditPredictionRejectReason::Canceled,
1392 true,
1393 cx,
1394 );
1395 });
1396
1397 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1398 cx.run_until_parked();
1399
1400 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1401 respond_tx.send(()).unwrap();
1402
1403 // batched
1404 assert_eq!(reject_request.rejections.len(), 2);
1405 assert_eq!(
1406 reject_request.rejections[0],
1407 EditPredictionRejection {
1408 request_id: "test-1".to_string(),
1409 reason: EditPredictionRejectReason::Discarded,
1410 was_shown: false
1411 }
1412 );
1413 assert_eq!(
1414 reject_request.rejections[1],
1415 EditPredictionRejection {
1416 request_id: "test-2".to_string(),
1417 reason: EditPredictionRejectReason::Canceled,
1418 was_shown: true
1419 }
1420 );
1421
1422 // Reaching batch size limit sends without debounce
1423 ep_store.update(cx, |ep_store, cx| {
1424 for i in 0..70 {
1425 ep_store.reject_prediction(
1426 EditPredictionId(format!("batch-{}", i).into()),
1427 EditPredictionRejectReason::Discarded,
1428 false,
1429 cx,
1430 );
1431 }
1432 });
1433
1434 // First MAX/2 items are sent immediately
1435 cx.run_until_parked();
1436 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1437 respond_tx.send(()).unwrap();
1438
1439 assert_eq!(reject_request.rejections.len(), 50);
1440 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1441 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1442
1443 // Remaining items are debounced with the next batch
1444 cx.executor().advance_clock(Duration::from_secs(15));
1445 cx.run_until_parked();
1446
1447 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1448 respond_tx.send(()).unwrap();
1449
1450 assert_eq!(reject_request.rejections.len(), 20);
1451 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1452 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1453
1454 // Request failure
1455 ep_store.update(cx, |ep_store, cx| {
1456 ep_store.reject_prediction(
1457 EditPredictionId("retry-1".into()),
1458 EditPredictionRejectReason::Discarded,
1459 false,
1460 cx,
1461 );
1462 });
1463
1464 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1465 cx.run_until_parked();
1466
1467 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1468 assert_eq!(reject_request.rejections.len(), 1);
1469 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1470 // Simulate failure
1471 drop(_respond_tx);
1472
1473 // Add another rejection
1474 ep_store.update(cx, |ep_store, cx| {
1475 ep_store.reject_prediction(
1476 EditPredictionId("retry-2".into()),
1477 EditPredictionRejectReason::Discarded,
1478 false,
1479 cx,
1480 );
1481 });
1482
1483 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1484 cx.run_until_parked();
1485
1486 // Retry should include both the failed item and the new one
1487 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1488 respond_tx.send(()).unwrap();
1489
1490 assert_eq!(reject_request.rejections.len(), 2);
1491 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1492 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1493}
1494
1495// Skipped until we start including diagnostics in prompt
1496// #[gpui::test]
1497// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1498// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1499// let fs = FakeFs::new(cx.executor());
1500// fs.insert_tree(
1501// "/root",
1502// json!({
1503// "foo.md": "Hello!\nBye"
1504// }),
1505// )
1506// .await;
1507// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1508
1509// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1510// let diagnostic = lsp::Diagnostic {
1511// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1512// severity: Some(lsp::DiagnosticSeverity::ERROR),
1513// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1514// ..Default::default()
1515// };
1516
1517// project.update(cx, |project, cx| {
1518// project.lsp_store().update(cx, |lsp_store, cx| {
1519// // Create some diagnostics
1520// lsp_store
1521// .update_diagnostics(
1522// LanguageServerId(0),
1523// lsp::PublishDiagnosticsParams {
1524// uri: path_to_buffer_uri.clone(),
1525// diagnostics: vec![diagnostic],
1526// version: None,
1527// },
1528// None,
1529// language::DiagnosticSourceKind::Pushed,
1530// &[],
1531// cx,
1532// )
1533// .unwrap();
1534// });
1535// });
1536
1537// let buffer = project
1538// .update(cx, |project, cx| {
1539// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1540// project.open_buffer(path, cx)
1541// })
1542// .await
1543// .unwrap();
1544
1545// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1546// let position = snapshot.anchor_before(language::Point::new(0, 0));
1547
1548// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1549// ep_store.request_prediction(&project, &buffer, position, cx)
1550// });
1551
1552// let (request, _respond_tx) = req_rx.next().await.unwrap();
1553
1554// assert_eq!(request.diagnostic_groups.len(), 1);
1555// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1556// .unwrap();
1557// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1558// assert_eq!(
1559// value,
1560// json!({
1561// "entries": [{
1562// "range": {
1563// "start": 8,
1564// "end": 10
1565// },
1566// "diagnostic": {
1567// "source": null,
1568// "code": null,
1569// "code_description": null,
1570// "severity": 1,
1571// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1572// "markdown": null,
1573// "group_id": 0,
1574// "is_primary": true,
1575// "is_disk_based": false,
1576// "is_unnecessary": false,
1577// "source_kind": "Pushed",
1578// "data": null,
1579// "underline": true
1580// }
1581// }],
1582// "primary_ix": 0
1583// })
1584// );
1585// }
1586
1587// Generate a model response that would apply the given diff to the active file.
1588fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1589 let excerpt =
1590 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1591 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1592
1593 PredictEditsV3Response {
1594 request_id: Uuid::new_v4().to_string(),
1595 output: new_excerpt,
1596 }
1597}
1598
1599fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1600 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1601}
1602
1603struct RequestChannels {
1604 predict: mpsc::UnboundedReceiver<(
1605 PredictEditsV3Request,
1606 oneshot::Sender<PredictEditsV3Response>,
1607 )>,
1608 reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1609}
1610
1611fn init_test_with_fake_client(
1612 cx: &mut TestAppContext,
1613) -> (Entity<EditPredictionStore>, RequestChannels) {
1614 cx.update(move |cx| {
1615 let settings_store = SettingsStore::test(cx);
1616 cx.set_global(settings_store);
1617 zlog::init_test();
1618
1619 let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1620 let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1621
1622 let http_client = FakeHttpClient::create({
1623 move |req| {
1624 let uri = req.uri().path().to_string();
1625 let mut body = req.into_body();
1626 let predict_req_tx = predict_req_tx.clone();
1627 let reject_req_tx = reject_req_tx.clone();
1628 async move {
1629 let resp = match uri.as_str() {
1630 "/client/llm_tokens" => serde_json::to_string(&json!({
1631 "token": "test"
1632 }))
1633 .unwrap(),
1634 "/predict_edits/v3" => {
1635 let mut buf = Vec::new();
1636 body.read_to_end(&mut buf).await.ok();
1637 let decompressed = zstd::decode_all(&buf[..]).unwrap();
1638 let req = serde_json::from_slice(&decompressed).unwrap();
1639
1640 let (res_tx, res_rx) = oneshot::channel();
1641 predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1642 serde_json::to_string(&res_rx.await?).unwrap()
1643 }
1644 "/predict_edits/reject" => {
1645 let mut buf = Vec::new();
1646 body.read_to_end(&mut buf).await.ok();
1647 let req = serde_json::from_slice(&buf).unwrap();
1648
1649 let (res_tx, res_rx) = oneshot::channel();
1650 reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1651 serde_json::to_string(&res_rx.await?).unwrap()
1652 }
1653 _ => {
1654 panic!("Unexpected path: {}", uri)
1655 }
1656 };
1657
1658 Ok(Response::builder().body(resp.into()).unwrap())
1659 }
1660 }
1661 });
1662
1663 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1664 client.cloud_client().set_credentials(1, "test".into());
1665
1666 language_model::init(client.clone(), cx);
1667
1668 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1669 let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1670
1671 (
1672 ep_store,
1673 RequestChannels {
1674 predict: predict_req_rx,
1675 reject: reject_req_rx,
1676 },
1677 )
1678 })
1679}
1680
/// Checks `EditPrediction::interpolate` as the buffer is edited after the prediction was
/// made. Expected behavior:
/// - ranges shift to follow the edits;
/// - parts of the prediction the user has already typed are trimmed away;
/// - a conflicting edit makes interpolation return `None`.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: replace "rem" (2..5) with "REM" and delete "um" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Only `edits`, `buffer` and `snapshot` matter for interpolation in this test;
    // the prompt inputs are empty placeholders.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: the prediction interpolates to its original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting "rem" turns the replacement into a pure insertion and shifts the
        // deletion left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing "R" over "rem": only "EM" is left to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E": only "M" is left to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing "M" completes the insertion; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Removing the typed "M" brings the "M" insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Deleting "um" (8..10) performs the predicted deletion; only "M" remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // After deleting 4..6, the prediction no longer interpolates.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1800
1801#[gpui::test]
1802async fn test_clean_up_diff(cx: &mut TestAppContext) {
1803 init_test(cx);
1804
1805 assert_eq!(
1806 apply_edit_prediction(
1807 indoc! {"
1808 fn main() {
1809 let word_1 = \"lorem\";
1810 let range = word.len()..word.len();
1811 }
1812 "},
1813 indoc! {"
1814 fn main() {
1815 let word_1 = \"lorem\";
1816 let range = word_1.len()..word_1.len();
1817 }
1818 "},
1819 cx,
1820 )
1821 .await,
1822 indoc! {"
1823 fn main() {
1824 let word_1 = \"lorem\";
1825 let range = word_1.len()..word_1.len();
1826 }
1827 "},
1828 );
1829
1830 assert_eq!(
1831 apply_edit_prediction(
1832 indoc! {"
1833 fn main() {
1834 let story = \"the quick\"
1835 }
1836 "},
1837 indoc! {"
1838 fn main() {
1839 let story = \"the quick brown fox jumps over the lazy dog\";
1840 }
1841 "},
1842 cx,
1843 )
1844 .await,
1845 indoc! {"
1846 fn main() {
1847 let story = \"the quick brown fox jumps over the lazy dog\";
1848 }
1849 "},
1850 );
1851}
1852
1853#[gpui::test]
1854async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1855 init_test(cx);
1856
1857 let buffer_content = "lorem\n";
1858 let completion_response = "lorem\nipsum\n";
1859
1860 assert_eq!(
1861 apply_edit_prediction(buffer_content, completion_response, cx).await,
1862 "lorem\nipsum\n"
1863 );
1864}
1865
1866#[gpui::test]
1867async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1868 // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1869 // When the buffer ends without a trailing newline, but the model returns output
1870 // with a trailing newline, zeta2 should normalize both sides before diffing
1871 // so no spurious newline is inserted.
1872 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1873 let fs = FakeFs::new(cx.executor());
1874
1875 // Single line buffer with no trailing newline
1876 fs.insert_tree(
1877 "/root",
1878 json!({
1879 "foo.txt": "hello"
1880 }),
1881 )
1882 .await;
1883 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1884
1885 let buffer = project
1886 .update(cx, |project, cx| {
1887 let path = project
1888 .find_project_path(path!("root/foo.txt"), cx)
1889 .unwrap();
1890 project.open_buffer(path, cx)
1891 })
1892 .await
1893 .unwrap();
1894
1895 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1896 let position = snapshot.anchor_before(language::Point::new(0, 5));
1897
1898 ep_store.update(cx, |ep_store, cx| {
1899 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1900 });
1901
1902 let (_request, respond_tx) = requests.predict.next().await.unwrap();
1903
1904 // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1905 // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1906 let response = PredictEditsV3Response {
1907 request_id: Uuid::new_v4().to_string(),
1908 output: "hello world\n".to_string(),
1909 };
1910 respond_tx.send(response).unwrap();
1911
1912 cx.run_until_parked();
1913
1914 // The prediction should insert " world" without adding a newline
1915 ep_store.update(cx, |ep_store, cx| {
1916 let prediction = ep_store
1917 .prediction_at(&buffer, None, &project, cx)
1918 .expect("should have prediction");
1919 let edits: Vec<_> = prediction
1920 .edits
1921 .iter()
1922 .map(|(range, text)| {
1923 let snapshot = buffer.read(cx).snapshot();
1924 (range.to_offset(&snapshot), text.clone())
1925 })
1926 .collect();
1927 assert_eq!(edits, vec![(5..5, " world".into())]);
1928 });
1929}
1930
1931fn init_test(cx: &mut TestAppContext) {
1932 cx.update(|cx| {
1933 let settings_store = SettingsStore::test(cx);
1934 cx.set_global(settings_store);
1935 });
1936}
1937
1938async fn apply_edit_prediction(
1939 buffer_content: &str,
1940 completion_response: &str,
1941 cx: &mut TestAppContext,
1942) -> String {
1943 let fs = project::FakeFs::new(cx.executor());
1944 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1945 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1946 let (ep_store, response) = make_test_ep_store(&project, cx).await;
1947 *response.lock() = completion_response.to_string();
1948 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1949 buffer.update(cx, |buffer, cx| {
1950 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1951 });
1952 buffer.read_with(cx, |buffer, _| buffer.text())
1953}
1954
1955async fn run_edit_prediction(
1956 buffer: &Entity<Buffer>,
1957 project: &Entity<Project>,
1958 ep_store: &Entity<EditPredictionStore>,
1959 cx: &mut TestAppContext,
1960) -> EditPrediction {
1961 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1962 ep_store.update(cx, |ep_store, cx| {
1963 ep_store.register_buffer(buffer, &project, cx)
1964 });
1965 cx.background_executor.run_until_parked();
1966 let prediction_task = ep_store.update(cx, |ep_store, cx| {
1967 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1968 });
1969 prediction_task.await.unwrap().unwrap().prediction.unwrap()
1970}
1971
1972async fn make_test_ep_store(
1973 project: &Entity<Project>,
1974 cx: &mut TestAppContext,
1975) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
1976 let default_response = "hello world\n".to_string();
1977 let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
1978 let http_client = FakeHttpClient::create({
1979 let completion_response = completion_response.clone();
1980 let mut next_request_id = 0;
1981 move |req| {
1982 let completion_response = completion_response.clone();
1983 async move {
1984 match (req.method(), req.uri().path()) {
1985 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1986 .status(200)
1987 .body(
1988 serde_json::to_string(&CreateLlmTokenResponse {
1989 token: LlmToken("the-llm-token".to_string()),
1990 })
1991 .unwrap()
1992 .into(),
1993 )
1994 .unwrap()),
1995 (&Method::POST, "/predict_edits/v3") => {
1996 next_request_id += 1;
1997 Ok(http_client::Response::builder()
1998 .status(200)
1999 .body(
2000 serde_json::to_string(&PredictEditsV3Response {
2001 request_id: format!("request-{next_request_id}"),
2002 output: completion_response.lock().clone(),
2003 })
2004 .unwrap()
2005 .into(),
2006 )
2007 .unwrap())
2008 }
2009 _ => Ok(http_client::Response::builder()
2010 .status(404)
2011 .body("Not Found".into())
2012 .unwrap()),
2013 }
2014 }
2015 }
2016 });
2017
2018 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2019 cx.update(|cx| {
2020 RefreshLlmTokenListener::register(client.clone(), cx);
2021 });
2022 let _server = FakeServer::for_client(42, &client, cx).await;
2023
2024 let ep_store = cx.new(|cx| {
2025 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2026 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2027
2028 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2029 for worktree in worktrees {
2030 let worktree_id = worktree.read(cx).id();
2031 ep_store
2032 .get_or_init_project(project, cx)
2033 .license_detection_watchers
2034 .entry(worktree_id)
2035 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2036 }
2037
2038 ep_store
2039 });
2040
2041 (ep_store, completion_response)
2042}
2043
2044fn to_completion_edits(
2045 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2046 buffer: &Entity<Buffer>,
2047 cx: &App,
2048) -> Vec<(Range<Anchor>, Arc<str>)> {
2049 let buffer = buffer.read(cx);
2050 iterator
2051 .into_iter()
2052 .map(|(range, text)| {
2053 (
2054 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2055 text,
2056 )
2057 })
2058 .collect()
2059}
2060
2061fn from_completion_edits(
2062 editor_edits: &[(Range<Anchor>, Arc<str>)],
2063 buffer: &Entity<Buffer>,
2064 cx: &App,
2065) -> Vec<(Range<usize>, Arc<str>)> {
2066 let buffer = buffer.read(cx);
2067 editor_edits
2068 .iter()
2069 .map(|(range, text)| {
2070 (
2071 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2072 text.clone(),
2073 )
2074 })
2075 .collect()
2076}
2077
2078#[gpui::test]
2079async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2080 init_test(cx);
2081
2082 let fs = FakeFs::new(cx.executor());
2083 fs.insert_tree(
2084 "/project",
2085 serde_json::json!({
2086 "main.rs": "fn main() {\n \n}\n"
2087 }),
2088 )
2089 .await;
2090
2091 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2092
2093 let http_client = FakeHttpClient::create(|_req| async move {
2094 Ok(gpui::http_client::Response::builder()
2095 .status(401)
2096 .body("Unauthorized".into())
2097 .unwrap())
2098 });
2099
2100 let client =
2101 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2102 cx.update(|cx| {
2103 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2104 });
2105
2106 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2107
2108 let buffer = project
2109 .update(cx, |project, cx| {
2110 let path = project
2111 .find_project_path(path!("/project/main.rs"), cx)
2112 .unwrap();
2113 project.open_buffer(path, cx)
2114 })
2115 .await
2116 .unwrap();
2117
2118 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2119 ep_store.update(cx, |ep_store, cx| {
2120 ep_store.register_buffer(&buffer, &project, cx)
2121 });
2122 cx.background_executor.run_until_parked();
2123
2124 let completion_task = ep_store.update(cx, |ep_store, cx| {
2125 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2126 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2127 });
2128
2129 let result = completion_task.await;
2130 assert!(
2131 result.is_err(),
2132 "Without authentication and without custom URL, prediction should fail"
2133 );
2134}
2135
/// `compute_diff_between_snapshots` produces a unified diff between two snapshots of the
/// same buffer. Here two nearby insertions end up in a single hunk with three lines of
/// context on each side.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
            zero
            one
            two
            three
            four
            five
            six
            seven
            eight
            nine
            ten
            eleven
            twelve
            thirteen
            fourteen
            fifteen
            sixteen
            seventeen
            eighteen
            nineteen
            twenty
            twenty-one
            twenty-two
            twenty-three
            twenty-four
        "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insert at the later row first so that the second insertion's row (8) still refers
    // to the original text ("eight"); the first lands before "twelve".
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2203
2204#[ctor::ctor]
2205fn init_logger() {
2206 zlog::init_test();
2207}