1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{
33 BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
34 EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
35};
36
// Verifies `prediction_at` state transitions: a prediction that applies to the
// queried buffer surfaces as `Local`, while a prediction targeting a diagnostic
// in a different file surfaces as `Jump` — and becomes `Local` when queried
// from the target buffer itself.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the local prediction so the next one can be observed in isolation.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // Publishing the diagnostic triggers a new prediction request for 2.txt.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        // Seen from buffer1, the prediction targets another file, so it is a `Jump`.
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        // Queried from the target buffer, the same prediction is `Local`.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
172
// Happy path: a single `request_prediction` call produces one edit whose range
// and replacement text match the model's returned diff.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff rewrites "How" -> "How are you?", which interpolates to a single
    // insertion of " are you?" at the cursor position (line 1, column 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
239
// Verifies that recent buffer edits are recorded as edit-history events and
// included (as a unified diff) in the prompt sent to the prediction endpoint.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registering the buffer makes the store track its subsequent edits.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Type "How" on the (previously empty) second line.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt must contain the edit-history diff for the "How" insertion.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
312
// Verifies that a pause longer than `LAST_CHANGE_GROUPING_TIME` between edit
// bursts splits the edit history into two separate events, even when both
// bursts touch the same line (which would otherwise coalesce them).
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: only the pre-pause "How" insertion.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second event: both post-pause edits, coalesced together.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
390
// Exercises line-distance-based coalescing of edit-history events: edits within
// 8 lines of an existing event's range (checked against both its start and its
// end) merge into that event; edits farther away start a new event.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: render_events joins them with a "---" separator line.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
569
570fn render_events(events: &[StoredEvent]) -> String {
571 events
572 .iter()
573 .map(|e| {
574 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
575 diff.as_str()
576 })
577 .collect::<Vec<_>>()
578 .join("\n---\n")
579}
580
581fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
582 events
583 .iter()
584 .map(|e| {
585 let zeta_prompt::Event::BufferChange {
586 diff, predicted, ..
587 } = e.event.as_ref();
588 let prefix = if *predicted { "predicted" } else { "manual" };
589 format!("{}\n{}", prefix, diff)
590 })
591 .collect()
592}
593
// Exercises how the `predicted` flag interacts with event coalescing:
// manual edits coalesce with manual edits, predicted with predicted, a manual
// edit never merges into a preceding predicted event — but once the editing
// location moves away, older events at the prior location may merge regardless
// of flag, with `predicted` meaning "all constituent edits were predicted".
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    //
    // Reporting the change with `predicted == true` marks this edit as coming
    // from an accepted prediction.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
852
// A model response containing no edits yields no current prediction, and the
// request is automatically reported as rejected with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff: no edits to apply.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
908
// If the user edits the buffer while a request is in flight such that the
// model's returned edits interpolate to nothing (the user already typed the
// suggestion), the prediction is dropped and reported as `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is pending, the user types exactly what the model is
    // about to suggest, so the eventual edits interpolate to nothing.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
969
// A canned model response: a unified diff replacing "How" with "How are you?"
// in `root/foo.md`. Shared by several tests below as a simple valid prediction.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
979
// When a newer prediction is at least as good as the current one, it replaces
// it, and the replaced prediction is reported with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1062
// When a newer prediction is worse than the current one, the current prediction
// is kept and the newer one is rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1157
// When a later request completes before an earlier pending one, the earlier
// request is cancelled: its (late) response does not replace the current
// prediction, and it is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the earlier (already-cancelled) request respond.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1249
// Verifies pending-request bookkeeping when three refreshes overlap:
// - with two predictions already in flight, starting a third cancels the
//   older pending one (index 1 in `cancelled_predictions`);
// - the cancelled request's response is ignored when it eventually arrives;
// - the third response replaces the first as the current prediction;
// - both the cancellation and the replacement are reported to the reject
//   endpoint in a single batch, in that order.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Resolve the first request; it becomes the current prediction.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Resolve the cancelled (second) request; its response must be dropped.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // One batched reject request: the cancelled second, then the replaced first.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1387
// Edit-triggered and diagnostic-jump-triggered predictions use separate
// throttle timers: the first request of each kind is sent immediately, a
// second request of the same kind is held back, and both held requests go
// through after `THROTTLE_TIMEOUT` elapses.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1488
// Exercises how rejection reports are flushed to the server:
// - multiple rejections within the debounce window are sent as one batch;
// - hitting the batch-size limit flushes the first half immediately,
//   without waiting for the debounce;
// - the remainder rides along with the next debounced flush;
// - a failed flush (dropped responder) is retried, merged with any
//   rejections recorded since.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false,
            model_version: None,
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true,
            model_version: None,
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                None,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1612
1613// Skipped until we start including diagnostics in prompt
1614// #[gpui::test]
1615// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1616// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1617// let fs = FakeFs::new(cx.executor());
1618// fs.insert_tree(
1619// "/root",
1620// json!({
1621// "foo.md": "Hello!\nBye"
1622// }),
1623// )
1624// .await;
1625// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1626
1627// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1628// let diagnostic = lsp::Diagnostic {
1629// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1630// severity: Some(lsp::DiagnosticSeverity::ERROR),
1631// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1632// ..Default::default()
1633// };
1634
1635// project.update(cx, |project, cx| {
1636// project.lsp_store().update(cx, |lsp_store, cx| {
1637// // Create some diagnostics
1638// lsp_store
1639// .update_diagnostics(
1640// LanguageServerId(0),
1641// lsp::PublishDiagnosticsParams {
1642// uri: path_to_buffer_uri.clone(),
1643// diagnostics: vec![diagnostic],
1644// version: None,
1645// },
1646// None,
1647// language::DiagnosticSourceKind::Pushed,
1648// &[],
1649// cx,
1650// )
1651// .unwrap();
1652// });
1653// });
1654
1655// let buffer = project
1656// .update(cx, |project, cx| {
1657// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1658// project.open_buffer(path, cx)
1659// })
1660// .await
1661// .unwrap();
1662
1663// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1664// let position = snapshot.anchor_before(language::Point::new(0, 0));
1665
1666// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1667// ep_store.request_prediction(&project, &buffer, position, cx)
1668// });
1669
1670// let (request, _respond_tx) = req_rx.next().await.unwrap();
1671
1672// assert_eq!(request.diagnostic_groups.len(), 1);
1673// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1674// .unwrap();
1675// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1676// assert_eq!(
1677// value,
1678// json!({
1679// "entries": [{
1680// "range": {
1681// "start": 8,
1682// "end": 10
1683// },
1684// "diagnostic": {
1685// "source": null,
1686// "code": null,
1687// "code_description": null,
1688// "severity": 1,
1689// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1690// "markdown": null,
1691// "group_id": 0,
1692// "is_primary": true,
1693// "is_disk_based": false,
1694// "is_unnecessary": false,
1695// "source_kind": "Pushed",
1696// "data": null,
1697// "underline": true
1698// }
1699// }],
1700// "primary_ix": 0
1701// })
1702// );
1703// }
1704
1705// Generate a model response that would apply the given diff to the active file.
1706fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1707 let editable_range =
1708 zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1709 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1710 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1711
1712 PredictEditsV3Response {
1713 request_id: Uuid::new_v4().to_string(),
1714 editable_range,
1715 output: new_excerpt,
1716 model_version: None,
1717 }
1718}
1719
1720fn empty_response() -> PredictEditsV3Response {
1721 PredictEditsV3Response {
1722 request_id: Uuid::new_v4().to_string(),
1723 editable_range: 0..0,
1724 output: String::new(),
1725 model_version: None,
1726 }
1727}
1728
1729fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1730 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1731}
1732
1733fn assert_no_predict_request_ready(
1734 requests: &mut mpsc::UnboundedReceiver<(
1735 PredictEditsV3Request,
1736 oneshot::Sender<PredictEditsV3Response>,
1737 )>,
1738) {
1739 if requests.next().now_or_never().flatten().is_some() {
1740 panic!("Unexpected prediction request while throttled.");
1741 }
1742}
1743
/// Receiving halves of the channels through which the fake HTTP client in
/// `init_test_with_fake_client` hands intercepted requests to the test.
struct RequestChannels {
    // One item per `/predict_edits/v3` request; send on the paired
    // `oneshot::Sender` to deliver the model's response.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // One item per `/predict_edits/reject` request; send `()` to acknowledge,
    // or drop the sender to simulate a failed request.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1751
// Sets up an `EditPredictionStore` backed by a fake HTTP client that
// intercepts the LLM-token, predict, and reject endpoints. Predict and
// reject requests are forwarded over the returned `RequestChannels`, letting
// the test inspect each request and control when (or whether) it resolves.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: always succeeds with a fixed token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction: body is zstd-compressed JSON. Forward the
                        // parsed request to the test and block until it responds.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection reports: plain JSON body; dropping res_tx on
                        // the test side makes `res_rx.await?` fail the request.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1821
// `EditPrediction::interpolate` re-maps the predicted edits onto the buffer
// as the user types: edits the user has already made themselves are dropped,
// remaining edits shift with the text, and `None` is returned once the
// prediction no longer applies at all.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: "Lorem ipsum dolor" -> "LoREM ipsumdolor".
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Only `edits`, `buffer`, and `snapshot` matter for interpolation; the
    // remaining fields are filled with inert placeholder values.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
        model_version: None,
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the edits unchanged.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the first edit into an insertion
        // and shifts the second edit left.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the predicted text character by character shrinks the
        // remaining portion of the first edit...
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // ...until it is fully typed and drops out of the interpolation.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of what was typed brings the edit back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Manually performing the predicted deletion drops the second edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1942
// End-to-end through `apply_edit_prediction`: applying a prediction whose
// response differs from the buffer by small intra-line changes yields
// exactly the response text after diff clean-up.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Identifier rename touched in two places on the same line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extension of a string literal plus a trailing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1994
1995#[gpui::test]
1996async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1997 init_test(cx);
1998
1999 let buffer_content = "lorem\n";
2000 let completion_response = "lorem\nipsum\n";
2001
2002 assert_eq!(
2003 apply_edit_prediction(buffer_content, completion_response, cx).await,
2004 "lorem\nipsum\n"
2005 );
2006}
2007
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "hello".
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let excerpt_length = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
        editable_range: 0..excerpt_length,
        model_version: None,
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve the anchor-based edits to offsets for comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        // A single insertion at offset 5 — no trailing "\n" edit.
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2075
2076fn init_test(cx: &mut TestAppContext) {
2077 cx.update(|cx| {
2078 let settings_store = SettingsStore::test(cx);
2079 cx.set_global(settings_store);
2080 });
2081}
2082
2083async fn apply_edit_prediction(
2084 buffer_content: &str,
2085 completion_response: &str,
2086 cx: &mut TestAppContext,
2087) -> String {
2088 let fs = project::FakeFs::new(cx.executor());
2089 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2090 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2091 let (ep_store, response) = make_test_ep_store(&project, cx).await;
2092 *response.lock() = completion_response.to_string();
2093 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2094 buffer.update(cx, |buffer, cx| {
2095 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2096 });
2097 buffer.read_with(cx, |buffer, _| buffer.text())
2098}
2099
2100async fn run_edit_prediction(
2101 buffer: &Entity<Buffer>,
2102 project: &Entity<Project>,
2103 ep_store: &Entity<EditPredictionStore>,
2104 cx: &mut TestAppContext,
2105) -> EditPrediction {
2106 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2107 ep_store.update(cx, |ep_store, cx| {
2108 ep_store.register_buffer(buffer, &project, cx)
2109 });
2110 cx.background_executor.run_until_parked();
2111 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2112 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2113 });
2114 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2115}
2116
// Builds an `EditPredictionStore` wired to a fake HTTP server whose predict
// endpoint always answers with the `String` inside the returned
// `Arc<Mutex<String>>` (tests overwrite it to control the model output).
// Request ids are sequential ("request-1", "request-2", ...).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    // Token refresh: fixed token, always succeeds.
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction: decode the zstd-compressed request, then
                    // respond with the shared `completion_response` string
                    // spanning the whole cursor excerpt.
                    (Method::POST, "/predict_edits/v3") => {
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);

        // Install a license watcher per worktree so data-collection gating
        // has something to consult.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2199
2200fn to_completion_edits(
2201 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2202 buffer: &Entity<Buffer>,
2203 cx: &App,
2204) -> Vec<(Range<Anchor>, Arc<str>)> {
2205 let buffer = buffer.read(cx);
2206 iterator
2207 .into_iter()
2208 .map(|(range, text)| {
2209 (
2210 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2211 text,
2212 )
2213 })
2214 .collect()
2215}
2216
2217fn from_completion_edits(
2218 editor_edits: &[(Range<Anchor>, Arc<str>)],
2219 buffer: &Entity<Buffer>,
2220 cx: &App,
2221) -> Vec<(Range<usize>, Arc<str>)> {
2222 let buffer = buffer.read(cx);
2223 editor_edits
2224 .iter()
2225 .map(|(range, text)| {
2226 (
2227 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2228 text.clone(),
2229 )
2230 })
2231 .collect()
2232}
2233
// With a server that answers 401 to everything and no custom prediction URL
// configured, requesting a prediction must fail rather than silently
// proceed unauthenticated.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request — including token refresh — is rejected with 401.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2291
// Checks that `compute_diff_between_snapshots` produces a unified diff for
// two insertions made between snapshots, and that both insertions land in a
// single hunk (their context lines overlap given the default context size).
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insert bottom-up: the row-12 edit happens first so the row-8 edit's
    // coordinates are still valid; afterwards the first insertion sits one
    // row lower in the final text.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Both insertions are close enough to share context lines, so the diff
    // is a single hunk containing the two `+` lines.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2359
// Verifies that `EditPredictionStore::next_diagnostic_location` avoids
// diagnostics near a collaborator's cursor — both within the active buffer
// and when searching across other files — unless the local cursor is also
// near that region. The exact "near" distance threshold lives in the store
// implementation (not visible here); the rows chosen below (3/25/50 in a
// 60-line file) are assumed to sit clearly on either side of it.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Simulates a remote collaborator by applying an `UpdateSelections` op
    // from a different replica id, placing their cursor at `row`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Pushes one ERROR diagnostic per row in `rows` into the project's LSP
    // store for the file at `uri_path`.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60-line active file so rows 3, 25, and 50 are well separated.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator parked at row 5; diagnostics at rows 3, 25, 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator. The
    // diagnostic at row 3 (near the collaborator) must be skipped.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4 — also near the collaborator — so the
    // collaborator-adjacent diagnostic at row 3 is no longer excluded.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50 — the diagnostic right at the cursor
    // wins, far from the collaborator.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4: cross-file search. Two other files each have a diagnostic at
    // row 0; a collaborator sits in collab_file.txt, so only free_file.txt
    // is an acceptable jump target.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // Search range covers the whole active file so no same-file diagnostic
    // remains outside it, forcing the search to move to other files.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2575
// Verifies the "settled prediction" machinery: a prediction enqueued for a
// buffer region only fires its settled callback after the region has been
// quiet (no edits) for EDIT_PREDICTION_SETTLED_QUIESCENCE ("Q" below), and
// edits in one region do not delay a prediction pending in a disjoint
// region. Timing is driven deterministically via the test executor's clock.
#[gpui::test]
async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Buffer with two clearly separated regions:
    //   Region A = lines 0-9   (offsets 0..50)
    //   Region B = lines 20-29 (offsets 105..155)
    // A big gap in between so edits in one region never overlap the other.
    let mut content = String::new();
    for i in 0..30 {
        content.push_str(&format!("line {i:02}\n"));
    }

    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content.clone()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Collects (prediction id, text) pairs as predictions settle, so the
    // assertions below can inspect ordering and counts.
    let settled_events: Arc<Mutex<Vec<(EditPredictionId, String)>>> =
        Arc::new(Mutex::new(Vec::new()));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);

        let settled_events = settled_events.clone();
        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
            settled_events.lock().push((id, text));
        }));
    });

    // --- Phase 1: edit in region A and enqueue prediction A ---

    buffer.update(cx, |buffer, cx| {
        // Edit at the start of line 0.
        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    // Region A: first 10 lines of the buffer.
    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-a".into()),
            &project,
            &buffer,
            &snapshot_a,
            editable_region_a,
            cx,
        );
    });

    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---

    // Let the worker process the channel message before we start advancing.
    cx.run_until_parked();

    // Each iteration edits region A and advances only Q/2, so the region
    // never accumulates a full quiet window (total advanced: 3*Q/2).
    let mut region_a_edit_offset = 5;
    for _ in 0..3 {
        // Edit inside region A (not at the boundary) so `last_edit_at` is
        // updated before the worker's next wake.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
                None,
                cx,
            );
        });
        region_a_edit_offset += 1;
        cx.run_until_parked();

        cx.executor()
            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
        cx.run_until_parked();
        assert!(
            settled_events.lock().is_empty(),
            "no settled events should fire while region A is still being edited"
        );
    }

    // Still nothing settled.
    assert!(settled_events.lock().is_empty());

    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
    // Advance a small amount so B's quiescence window starts later than A's,
    // but not so much that A settles (A's last edit was at the start of
    // iteration 3, and it needs a full Q to settle).
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();
    assert!(settled_events.lock().is_empty());

    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));

    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-b".into()),
            &project,
            &buffer,
            &snapshot_b2,
            editable_region_b,
            cx,
        );
    });

    cx.run_until_parked();
    assert!(
        settled_events.lock().is_empty(),
        "neither prediction should have settled yet"
    );

    // --- Phase 4: let enough time pass for region A to settle ---
    // A's last edit was at T_a (during the last loop iteration). The worker is
    // sleeping until T_a + Q. We advance just enough to reach that wake time
    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
    // enqueued only Q/4 ago and stays pending.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            1,
            "only prediction A should have settled, got: {events:?}"
        );
        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
    }

    // --- Phase 5: let more time pass for region B to settle ---
    // B's last edit was Q/4 before A settled. The worker rescheduled to
    // B's last_edit_at + Q, which is 3Q/4 from now.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            2,
            "both predictions should have settled, got: {events:?}"
        );
        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
    }
}
2746
// Runs once at test-binary startup (via the `ctor` crate) so the test logger
// is installed before any test in this file executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}