1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{
21 Anchor, Buffer, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSeverity,
22 Operation, Point, Selection, SelectionGoal,
23};
24use language_model::RefreshLlmTokenListener;
25use lsp::LanguageServerId;
26use parking_lot::Mutex;
27use pretty_assertions::{assert_eq, assert_matches};
28use project::{FakeFs, Project};
29use serde_json::json;
30use settings::SettingsStore;
31use std::{path::Path, sync::Arc, time::Duration};
32use util::{
33 path,
34 test::{TextRangeMarker, marked_text_ranges_by},
35};
36use uuid::Uuid;
37use zeta_prompt::ZetaPromptInput;
38
39use crate::{
40 BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
41 EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
42};
43
// Covers `prediction_at` state reporting across two files: a prediction that
// lands in the queried buffer is surfaced as `BufferEditPrediction::Local`,
// while a prediction targeting a *different* file is surfaced as `Jump` from
// other buffers — and becomes `Local` once the target buffer itself is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and mark it as the active path so predictions are anchored there.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at row 1, col 3 — the end of "How".
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    // The fake client hands us the outgoing request plus a oneshot to answer it.
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Reply with a unified diff that edits the file the cursor is in.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // Edit is in the queried buffer -> reported as a Local prediction.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Clear it so the next prediction can be observed in isolation.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing the diagnostic triggers a new prediction request.
    // NOTE(review): the refresh appears to be driven by the diagnostic event —
    // confirm against the store's diagnostic subscription.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction lives elsewhere -> Jump,
    // pointing at 2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Once the target buffer is open, the same prediction reads as Local there.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
179
180#[gpui::test]
181async fn test_simple_request(cx: &mut TestAppContext) {
182 let (ep_store, mut requests) = init_test_with_fake_client(cx);
183 let fs = FakeFs::new(cx.executor());
184 fs.insert_tree(
185 "/root",
186 json!({
187 "foo.md": "Hello!\nHow\nBye\n"
188 }),
189 )
190 .await;
191 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
192
193 let buffer = project
194 .update(cx, |project, cx| {
195 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
196 project.open_buffer(path, cx)
197 })
198 .await
199 .unwrap();
200 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
201 let position = snapshot.anchor_before(language::Point::new(1, 3));
202
203 let prediction_task = ep_store.update(cx, |ep_store, cx| {
204 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
205 });
206
207 let (request, respond_tx) = requests.predict.next().await.unwrap();
208
209 // TODO Put back when we have a structured request again
210 // assert_eq!(
211 // request.excerpt_path.as_ref(),
212 // Path::new(path!("root/foo.md"))
213 // );
214 // assert_eq!(
215 // request.cursor_point,
216 // Point {
217 // line: Line(1),
218 // column: 3
219 // }
220 // );
221
222 respond_tx
223 .send(model_response(
224 &request,
225 indoc! { r"
226 --- a/root/foo.md
227 +++ b/root/foo.md
228 @@ ... @@
229 Hello!
230 -How
231 +How are you?
232 Bye
233 "},
234 ))
235 .unwrap();
236
237 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
238
239 assert_eq!(prediction.edits.len(), 1);
240 assert_eq!(
241 prediction.edits[0].0.to_point(&snapshot).start,
242 language::Point::new(1, 3)
243 );
244 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
245}
246
// Verifies that recent buffer edits are included in the prediction request:
// after registering the buffer and typing "How", the outgoing prompt must
// contain the edit-history diff for that change.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register first so the store observes the edit below as history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Type "How" on the (empty) second line.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The rendered prompt must carry the edit-history diff for the "How" insertion.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
319
// Verifies that a pause longer than `LAST_CHANGE_GROUPING_TIME` splits edit
// history into two separate events even when all edits touch the same line
// (which would otherwise coalesce into one event).
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 1: only the pre-pause "How" insertion.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 2: both post-pause edits, coalesced together.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
397
// Exercises the line-distance rule for coalescing edit-history events:
// edits within ~8 lines of an existing event's range (measured from either
// its start or its end) merge into that event; edits farther away start a
// new event.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the coalesced near-range edits, then the far-away one.
    // (render_events joins events with "---".)
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
576
577fn render_events(events: &[StoredEvent]) -> String {
578 events
579 .iter()
580 .map(|e| {
581 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
582 diff.as_str()
583 })
584 .collect::<Vec<_>>()
585 .join("\n---\n")
586}
587
588fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
589 events
590 .iter()
591 .map(|e| {
592 let zeta_prompt::Event::BufferChange {
593 diff, predicted, ..
594 } = e.event.as_ref();
595 let prefix = if *predicted { "predicted" } else { "manual" };
596 format!("{}\n{}", prefix, diff)
597 })
598 .collect()
599}
600
// Exercises how the `predicted` flag interacts with event coalescing:
// manual and predicted edits never merge across the manual/predicted
// boundary while the event is "live", but once editing moves far away the
// prior location's events may merge — with `predicted` true only if ALL
// constituent edits were predicted.
// NOTE(review): `report_changes_for_buffer(.., true, ..)` appears to mark the
// pending buffer changes as prediction-originated — confirm in the store.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
859
// An empty model response yields no prediction, and the request is reported
// back to the server as rejected with reason `Empty` (never shown to the user).
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff; remember the request id for the rejection check.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // No prediction should be surfaced for the buffer.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
915
// If the user types the predicted text themselves before the response lands,
// interpolating the prediction against the new buffer contents leaves nothing
// to apply: no prediction is surfaced and the request is rejected with
// `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the predicted change manually while the request is still in flight.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
976
/// Canned model response used by several tests below: a unified diff against
/// `root/foo.md` ("Hello!\nHow\nBye\n") that rewrites "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
986
// A newer prediction of equal quality replaces the current one, and the
// replaced prediction is reported back as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First prediction is current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1069
// If a newer prediction is *worse* than the current one (here: a strict prefix
// of it), the current prediction is kept and the newer one is rejected with
// reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1164
// Starting a second refresh cancels the still-pending first request: even if
// the first request's response arrives later, it must not displace the second
// prediction, and the first is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the (cancelled) first request's response arrive late.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1256
// Verifies the in-flight request cap: with two prediction requests pending,
// starting a third cancels the second (the newer pending one), the cancelled
// response can never become the current prediction, and both the cancelled
// and the later-replaced predictions are reported to the reject endpoint.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled request; its result must be discarded.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected (Canceled), first as Replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1394
// Verifies that the throttle applied to cursor-edit-triggered predictions and
// the throttle applied to diagnostic-jump-triggered predictions are tracked
// separately: the first request of each kind goes through immediately, the
// second of each kind is throttled, and both proceed after THROTTLE_TIMEOUT.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1495
// Verifies that two refresh calls issued in the same synchronous frame are
// collapsed into a single prediction request to the server.
#[gpui::test]
async fn test_same_frame_duplicate_requests_deduplicated(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Enqueue two refresh calls in the same synchronous frame (no yielding).
    // Both `cx.spawn` tasks are created before either executes, so they both
    // capture the same `proceed_count_at_enqueue`. Only the first task should
    // pass the deduplication gate; the second should be skipped.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Let both spawned tasks run to completion (including any throttle waits).
    cx.run_until_parked();

    // Exactly one prediction request should have been sent.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(&request, SIMPLE_DIFF))
        .unwrap();
    cx.run_until_parked();

    // No second request should be pending.
    assert_no_predict_request_ready(&mut requests.predict);
}
1541
1542#[gpui::test]
1543async fn test_rejections_flushing(cx: &mut TestAppContext) {
1544 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1545
1546 ep_store.update(cx, |ep_store, cx| {
1547 ep_store.reject_prediction(
1548 EditPredictionId("test-1".into()),
1549 EditPredictionRejectReason::Discarded,
1550 false,
1551 None,
1552 cx,
1553 );
1554 ep_store.reject_prediction(
1555 EditPredictionId("test-2".into()),
1556 EditPredictionRejectReason::Canceled,
1557 true,
1558 None,
1559 cx,
1560 );
1561 });
1562
1563 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1564 cx.run_until_parked();
1565
1566 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1567 respond_tx.send(()).unwrap();
1568
1569 // batched
1570 assert_eq!(reject_request.rejections.len(), 2);
1571 assert_eq!(
1572 reject_request.rejections[0],
1573 EditPredictionRejection {
1574 request_id: "test-1".to_string(),
1575 reason: EditPredictionRejectReason::Discarded,
1576 was_shown: false,
1577 model_version: None,
1578 }
1579 );
1580 assert_eq!(
1581 reject_request.rejections[1],
1582 EditPredictionRejection {
1583 request_id: "test-2".to_string(),
1584 reason: EditPredictionRejectReason::Canceled,
1585 was_shown: true,
1586 model_version: None,
1587 }
1588 );
1589
1590 // Reaching batch size limit sends without debounce
1591 ep_store.update(cx, |ep_store, cx| {
1592 for i in 0..70 {
1593 ep_store.reject_prediction(
1594 EditPredictionId(format!("batch-{}", i).into()),
1595 EditPredictionRejectReason::Discarded,
1596 false,
1597 None,
1598 cx,
1599 );
1600 }
1601 });
1602
1603 // First MAX/2 items are sent immediately
1604 cx.run_until_parked();
1605 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1606 respond_tx.send(()).unwrap();
1607
1608 assert_eq!(reject_request.rejections.len(), 50);
1609 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1610 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1611
1612 // Remaining items are debounced with the next batch
1613 cx.executor().advance_clock(Duration::from_secs(15));
1614 cx.run_until_parked();
1615
1616 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1617 respond_tx.send(()).unwrap();
1618
1619 assert_eq!(reject_request.rejections.len(), 20);
1620 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1621 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1622
1623 // Request failure
1624 ep_store.update(cx, |ep_store, cx| {
1625 ep_store.reject_prediction(
1626 EditPredictionId("retry-1".into()),
1627 EditPredictionRejectReason::Discarded,
1628 false,
1629 None,
1630 cx,
1631 );
1632 });
1633
1634 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1635 cx.run_until_parked();
1636
1637 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1638 assert_eq!(reject_request.rejections.len(), 1);
1639 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1640 // Simulate failure
1641 drop(_respond_tx);
1642
1643 // Add another rejection
1644 ep_store.update(cx, |ep_store, cx| {
1645 ep_store.reject_prediction(
1646 EditPredictionId("retry-2".into()),
1647 EditPredictionRejectReason::Discarded,
1648 false,
1649 None,
1650 cx,
1651 );
1652 });
1653
1654 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1655 cx.run_until_parked();
1656
1657 // Retry should include both the failed item and the new one
1658 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1659 respond_tx.send(()).unwrap();
1660
1661 assert_eq!(reject_request.rejections.len(), 2);
1662 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1663 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1664}
1665
// Verifies zeta::active_buffer_diagnostics: only diagnostics whose range
// intersects the search range are returned (with the highest-severity one
// selected when the search range covers a single function), and the reported
// snippet/row/offset values match the buffer content.
#[gpui::test]
fn test_active_buffer_diagnostics_fetching(cx: &mut TestAppContext) {
    // «» marks diagnostic ranges; [] marks the search range; ˇ is a cursor marker.
    let diagnostic_marker: TextRangeMarker = ('«', '»').into();
    let search_range_marker: TextRangeMarker = ('[', ']').into();

    let (text, mut ranges) = marked_text_ranges_by(
        indoc! {r#"
            fn alpha() {
                let «first_value» = 1;
            }

            [fn beta() {
                let «second_value» = 2;
                let third_value = second_value + missing_symbol;
            }ˇ]

            fn gamma() {
                let «fourth_value» = missing_other_symbol;
            }
        "#},
        vec![diagnostic_marker.clone(), search_range_marker.clone()],
    );

    let diagnostic_ranges = ranges.remove(&diagnostic_marker).unwrap_or_default();
    let search_ranges = ranges.remove(&search_range_marker).unwrap_or_default();

    let buffer = cx.new(|cx| Buffer::local(&text, cx));

    // Attach one diagnostic per marked range: warning, error, then hint.
    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let diagnostics = DiagnosticSet::new(
            diagnostic_ranges
                .iter()
                .enumerate()
                .map(|(index, range)| DiagnosticEntry {
                    range: snapshot.offset_to_point_utf16(range.start)
                        ..snapshot.offset_to_point_utf16(range.end),
                    diagnostic: Diagnostic {
                        severity: match index {
                            0 => DiagnosticSeverity::WARNING,
                            1 => DiagnosticSeverity::ERROR,
                            _ => DiagnosticSeverity::HINT,
                        },
                        message: match index {
                            0 => "first warning".to_string(),
                            1 => "second error".to_string(),
                            _ => "third hint".to_string(),
                        },
                        group_id: index + 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                }),
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let search_range = snapshot.offset_to_point(search_ranges[0].start)
        ..snapshot.offset_to_point(search_ranges[0].end);

    let active_buffer_diagnostics = zeta::active_buffer_diagnostics(&snapshot, search_range, 100);

    // Only the error inside the searched `fn beta` block should be returned.
    assert_eq!(
        active_buffer_diagnostics,
        vec![zeta_prompt::ActiveBufferDiagnostic {
            severity: Some(1),
            message: "second error".to_string(),
            snippet: text,
            snippet_buffer_row_range: 5..5,
            diagnostic_range_in_snippet: 61..73,
        }]
    );

    // Second scenario: three diagnostics on separate rows; a search range
    // spanning rows 2..4 should return the two diagnostics it touches.
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                one
                two
                three
                four
                five
            "},
            cx,
        )
    });

    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let diagnostics = DiagnosticSet::new(
            vec![
                DiagnosticEntry {
                    range: text::PointUtf16::new(0, 0)..text::PointUtf16::new(0, 3),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::ERROR,
                        message: "row zero".to_string(),
                        group_id: 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(2, 0)..text::PointUtf16::new(2, 5),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::WARNING,
                        message: "row two".to_string(),
                        group_id: 2,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(4, 0)..text::PointUtf16::new(4, 4),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::INFORMATION,
                        message: "row four".to_string(),
                        group_id: 3,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
            ],
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    let active_buffer_diagnostics =
        zeta::active_buffer_diagnostics(&snapshot, Point::new(2, 0)..Point::new(4, 0), 100);

    assert_eq!(
        active_buffer_diagnostics
            .iter()
            .map(|diagnostic| (
                diagnostic.severity,
                diagnostic.message.clone(),
                diagnostic.snippet.clone(),
                diagnostic.snippet_buffer_row_range.clone(),
                diagnostic.diagnostic_range_in_snippet.clone(),
            ))
            .collect::<Vec<_>>(),
        vec![
            (
                Some(2),
                "row two".to_string(),
                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                2..2,
                8..13,
            ),
            (
                Some(3),
                "row four".to_string(),
                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                4..4,
                19..23,
            ),
        ]
    );
}
1832
1833// Generate a model response that would apply the given diff to the active file.
1834fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1835 let editable_range =
1836 zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1837 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1838 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1839
1840 PredictEditsV3Response {
1841 request_id: Uuid::new_v4().to_string(),
1842 editable_range,
1843 output: new_excerpt,
1844 model_version: None,
1845 }
1846}
1847
1848fn empty_response() -> PredictEditsV3Response {
1849 PredictEditsV3Response {
1850 request_id: Uuid::new_v4().to_string(),
1851 editable_range: 0..0,
1852 output: String::new(),
1853 model_version: None,
1854 }
1855}
1856
1857fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1858 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1859}
1860
1861fn assert_no_predict_request_ready(
1862 requests: &mut mpsc::UnboundedReceiver<(
1863 PredictEditsV3Request,
1864 oneshot::Sender<PredictEditsV3Response>,
1865 )>,
1866) {
1867 if requests.next().now_or_never().flatten().is_some() {
1868 panic!("Unexpected prediction request while throttled.");
1869 }
1870}
1871
// Receiving ends of the channels the fake HTTP client uses to hand each
// intercepted request to the test, paired with a oneshot sender the test uses
// to supply the response.
struct RequestChannels {
    // One entry per POST to /predict_edits/v3.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // One entry per POST to /predict_edits/reject.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1879
// Sets up an EditPredictionStore backed by a fake HTTP client. Every predict
// and reject request is forwarded over a channel to the test, which responds
// via the paired oneshot sender; dropping the sender makes the request fail.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answer inline with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction: request bodies are zstd-compressed JSON.
                        // Forward to the test and await its response.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection reporting: plain (uncompressed) JSON body.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(user_store.clone(), client.clone(), cx);
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1948
// Verifies EditPrediction::interpolate: as the user types (or deletes) text
// that overlaps the predicted edits, the remaining edits are re-anchored and
// shrunk accordingly, and interpolation returns None once an edit is
// invalidated by a conflicting change.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: replace "rem" (2..5) with "REM" and delete "um" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            active_buffer_diagnostics: vec![],
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
        model_version: None,
    };

    cx.update(|cx| {
        // Unchanged buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the to-be-replaced range turns it into a pure insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted replacement shrinks the edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the full replacement consumes the first edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of what was typed re-expands the pending insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion by hand consumes the second edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // A conflicting deletion invalidates the prediction entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
2071
// Verifies that applying a prediction produces a clean result when the model
// rewrites identifiers mid-line and when it extends a line past the original
// content (diff cleanup should not duplicate or drop surrounding text).
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Renaming `word` -> `word_1` in two places on the same line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extending a string literal and adding a trailing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
2123
2124#[gpui::test]
2125async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2126 init_test(cx);
2127
2128 let buffer_content = "lorem\n";
2129 let completion_response = "lorem\nipsum\n";
2130
2131 assert_eq!(
2132 apply_edit_prediction(buffer_content, completion_response, cx).await,
2133 "lorem\nipsum\n"
2134 );
2135}
2136
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Place the cursor at the end of "hello".
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let excerpt_length = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
        editable_range: 0..excerpt_length,
        model_version: None,
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve the prediction's anchored edits to concrete offsets.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2204
2205fn init_test(cx: &mut TestAppContext) {
2206 cx.update(|cx| {
2207 let settings_store = SettingsStore::test(cx);
2208 cx.set_global(settings_store);
2209 });
2210}
2211
2212async fn apply_edit_prediction(
2213 buffer_content: &str,
2214 completion_response: &str,
2215 cx: &mut TestAppContext,
2216) -> String {
2217 let fs = project::FakeFs::new(cx.executor());
2218 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2219 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2220 let (ep_store, response) = make_test_ep_store(&project, cx).await;
2221 *response.lock() = completion_response.to_string();
2222 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2223 buffer.update(cx, |buffer, cx| {
2224 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2225 });
2226 buffer.read_with(cx, |buffer, _| buffer.text())
2227}
2228
2229async fn run_edit_prediction(
2230 buffer: &Entity<Buffer>,
2231 project: &Entity<Project>,
2232 ep_store: &Entity<EditPredictionStore>,
2233 cx: &mut TestAppContext,
2234) -> EditPrediction {
2235 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2236 ep_store.update(cx, |ep_store, cx| {
2237 ep_store.register_buffer(buffer, &project, cx)
2238 });
2239 cx.background_executor.run_until_parked();
2240 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2241 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2242 });
2243 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2244}
2245
// Builds an EditPredictionStore whose HTTP client answers every prediction
// request inline with the text stored in the returned Arc<Mutex<String>>.
// Tests mutate that string to control the next canned completion.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        // Monotonic counter so each response gets a distinct request id.
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    // Token refresh: answer with a fixed dummy token.
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction: decode the zstd-compressed JSON request and
                    // respond with the currently-configured completion text.
                    (Method::POST, "/predict_edits/v3") => {
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);

        // Install license watchers so data-collection gating has worktree info.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2329
2330fn to_completion_edits(
2331 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2332 buffer: &Entity<Buffer>,
2333 cx: &App,
2334) -> Vec<(Range<Anchor>, Arc<str>)> {
2335 let buffer = buffer.read(cx);
2336 iterator
2337 .into_iter()
2338 .map(|(range, text)| {
2339 (
2340 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2341 text,
2342 )
2343 })
2344 .collect()
2345}
2346
2347fn from_completion_edits(
2348 editor_edits: &[(Range<Anchor>, Arc<str>)],
2349 buffer: &Entity<Buffer>,
2350 cx: &App,
2351) -> Vec<(Range<usize>, Arc<str>)> {
2352 let buffer = buffer.read(cx);
2353 editor_edits
2354 .iter()
2355 .map(|(range, text)| {
2356 (
2357 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2358 text.clone(),
2359 )
2360 })
2361 .collect()
2362}
2363
/// When the client is unauthenticated (every HTTP request 401s) and no custom
/// prediction URL is configured, requesting a prediction must return an error
/// rather than silently producing a prediction.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Simulate a fully unauthenticated client: every request is rejected.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor inside the function body, mirroring a realistic completion point.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2422
/// Verifies that `compute_diff_between_snapshots` renders two nearby
/// insertions as a single unified-diff hunk with surrounding context lines.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insertions are applied bottom-up (row 12 first, then row 8) so each
    // `Point` still refers to the original, unshifted row layout.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Both insertions are close enough to merge into one hunk.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2490
/// Verifies that `next_diagnostic_location` avoids jumping to diagnostics
/// that sit near a collaborator's cursor — both within the active buffer and
/// when picking a different file entirely — unless the local cursor is also
/// near that region.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Simulates a remote collaborator placing their cursor at `row` by
    // applying an `UpdateSelections` operation from a non-local replica (10).
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Pushes one ERROR diagnostic per row in `rows` (columns 0..5) to the
    // project's LSP store, as if published by language server 0.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 lines gives enough vertical room to separate "near" from "far" rows.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator at row 5; diagnostics at rows 3, 25, and 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator. Row 3 (near
    // the collaborator, far from us) must be skipped.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, adjacent to the collaborator's region.
    // Row 3 is then a legitimate target and must not be filtered out.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50, right on a diagnostic that is far from
    // the collaborator — it should be picked.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4: cross-file jump. Two candidate files each have a diagnostic at
    // row 0, but one of them has a collaborator cursor in it.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // Search range spans the whole active buffer so no same-file diagnostic
    // remains eligible, forcing a cross-file jump.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2706
/// Exercises the "settled prediction" worker: a prediction enqueued for an
/// edited region settles only after `EDIT_PREDICTION_SETTLED_QUIESCENCE` of
/// no further edits within that region, and each region's quiescence window
/// is tracked independently.
#[gpui::test]
async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Buffer with two clearly separated regions:
    // Region A = lines 0-9 (offsets 0..50)
    // Region B = lines 20-29 (offsets 105..155)
    // A big gap in between so edits in one region never overlap the other.
    let mut content = String::new();
    for i in 0..30 {
        content.push_str(&format!("line {i:02}\n"));
    }

    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content.clone()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Each settled event records which prediction settled plus its text.
    type SettledEventRecord = (EditPredictionId, String);
    let settled_events: Arc<Mutex<Vec<SettledEventRecord>>> = Arc::new(Mutex::new(Vec::new()));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);

        // Capture settled events so the phases below can assert on ordering.
        let settled_events = settled_events.clone();
        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
            settled_events.lock().push((id, text));
        }));
    });

    // --- Phase 1: edit in region A and enqueue prediction A ---

    buffer.update(cx, |buffer, cx| {
        // Edit at the start of line 0.
        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    // Region A: first 10 lines of the buffer.
    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-a".into()),
            &project,
            &buffer,
            &snapshot_a,
            editable_region_a.clone(),
            None,
            cx,
        );
    });

    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---

    // Let the worker process the channel message before we start advancing.
    cx.run_until_parked();

    let mut region_a_edit_offset = 5;
    for _ in 0..3 {
        // Edit inside region A (not at the boundary) so `last_edit_at` is
        // updated before the worker's next wake.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
                None,
                cx,
            );
        });
        region_a_edit_offset += 1;
        cx.run_until_parked();

        cx.executor()
            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
        cx.run_until_parked();
        assert!(
            settled_events.lock().is_empty(),
            "no settled events should fire while region A is still being edited"
        );
    }

    // Still nothing settled.
    assert!(settled_events.lock().is_empty());

    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
    // Advance a small amount so B's quiescence window starts later than A's,
    // but not so much that A settles (A's last edit was at the start of
    // iteration 3, and it needs a full Q to settle).
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();
    assert!(settled_events.lock().is_empty());

    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));

    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Region B: lines 20..25.
    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-b".into()),
            &project,
            &buffer,
            &snapshot_b2,
            editable_region_b.clone(),
            None,
            cx,
        );
    });

    cx.run_until_parked();
    assert!(
        settled_events.lock().is_empty(),
        "neither prediction should have settled yet"
    );

    // --- Phase 4: let enough time pass for region A to settle ---
    // A's last edit was at T_a (during the last loop iteration). The worker is
    // sleeping until T_a + Q. We advance just enough to reach that wake time
    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
    // enqueued only Q/4 ago and stays pending.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            1,
            "prediction and capture_sample for A should have settled, got: {events:?}"
        );
        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
    }

    // --- Phase 5: let more time pass for region B to settle ---
    // B's last edit was Q/4 before A settled. The worker rescheduled to
    // B's last_edit_at + Q, which is 3Q/4 from now.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            2,
            "both prediction and capture_sample settled events should be emitted for each request, got: {events:?}"
        );
        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
    }
}
2881
/// Initializes test logging once for the whole test binary, before any test
/// runs (via the `ctor` constructor attribute).
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}