1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{
33 BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
34 EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
35};
36
// Verifies `prediction_at` routing between buffers: a prediction that edits
// the buffer the cursor is in is surfaced as `BufferEditPrediction::Local`,
// while a prediction targeting another file (triggered here by a diagnostic)
// is surfaced as `BufferEditPrediction::Jump` until that file's buffer is
// opened, at which point it becomes `Local` for that buffer.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Mark 1.txt as the active file so predictions are attributed to it.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Fake model answers with a diff that edits the currently-active file.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        // Edit applies to the buffer we asked about, so it's `Local`.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction so the next one starts from a clean slate.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                Hola!
                -Como
                +Como estas?
                Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        // From buffer1's perspective the prediction lives elsewhere: `Jump`
        // pointing at 2.txt.
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        // Once 2.txt itself is open, the same prediction is `Local` for it.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
172
// Happy-path round trip: `request_prediction` sends one request to the fake
// model, and the returned unified diff is interpolated against the buffer
// snapshot into a single insertion edit at the changed column.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor on line 1 ("How"), column 3 — right where the model will insert.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // "How" -> "How are you?" diffs down to inserting " are you?" at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
239
// Verifies that edits made to a registered buffer before a prediction request
// are recorded as edit-history events and included (as a unified diff) in the
// prompt that is sent to the model.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registration is what makes the store track this buffer's edits.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Type "How" on the empty second line; this becomes a history event.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The edit we just made must appear verbatim in the model prompt.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
312
// Verifies time-based splitting of edit history: edits separated by a pause
// longer than `LAST_CHANGE_GROUPING_TIME` are reported as two distinct
// `BufferChange` events, even when they touch the same line (which would
// otherwise coalesce them spatially).
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // `Event` has a single `BufferChange` variant, so this pattern is irrefutable.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
390
// Exercises the spatial (line-distance) coalescing rules of edit history:
// edits within 8 lines of an existing event's range — whether above or below
// it — merge into that event, while edits farther away start a new event.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // `render_events` joins separate events with a "---" separator line.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
569
570fn render_events(events: &[StoredEvent]) -> String {
571 events
572 .iter()
573 .map(|e| {
574 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
575 diff.as_str()
576 })
577 .collect::<Vec<_>>()
578 .join("\n---\n")
579}
580
581fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
582 events
583 .iter()
584 .map(|e| {
585 let zeta_prompt::Event::BufferChange {
586 diff, predicted, ..
587 } = e.event.as_ref();
588 let prefix = if *predicted { "predicted" } else { "manual" };
589 format!("{}\n{}", prefix, diff)
590 })
591 .collect()
592}
593
// Exercises how the `predicted` flag interacts with event coalescing:
// manual edits merge with manual edits, accepted predictions merge with
// predictions, a manual edit never merges into a preceding predicted event,
// and once editing moves elsewhere, older mixed events at the prior location
// may merge (with `predicted` meaning "all constituent edits were predicted").
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    //
    // `report_changes_for_buffer(.., true, ..)` marks the pending edit as
    // coming from an accepted prediction.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
852
// When the model returns an empty diff, no prediction is surfaced and the
// response is auto-reported as rejected with `EditPredictionRejectReason::Empty`
// (and `was_shown: false`, since the user never saw it).
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Empty model output: nothing to predict.
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
908
// When the buffer changes while a request is in flight such that the model's
// diff interpolates to no remaining edits (the user already typed the
// prediction), the prediction is dropped and reported as rejected with
// `EditPredictionRejectReason::InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is pending, the user types exactly what the model is
    // about to predict, leaving nothing for the prediction to add.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
969
/// Canned model output shared by several tests: a unified diff against
/// `root/foo.md` that rewrites "How" into "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
979
// When a newer prediction arrives while an older one is current, the newer
// one replaces it and the older one is reported as rejected with
// `EditPredictionRejectReason::Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1062
// When a newer prediction is judged worse than the current one, the current
// prediction is kept and the newer one is reported as rejected with
// `EditPredictionRejectReason::CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1157
// When two requests are in flight and the later one resolves first, the
// earlier request is cancelled: its (late) response does not replace the
// current prediction, and it is reported as rejected with
// `EditPredictionRejectReason::Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the stale first request resolve; it must not win.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1249
/// With two predictions already in flight, starting a third refresh cancels
/// the second (most recent pending) request. The first response still becomes
/// the current prediction, the cancelled second response is ignored when it
/// arrives, and the third response replaces the first. Both the cancellation
/// and the replacement are reported to the server as rejections, batched into
/// a single reject request.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Respond to the first request; it was not cancelled, so its prediction
    // should become current.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled (second) request; its prediction must be
    // discarded rather than replacing the current one.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both the cancellation (second) and the replacement (first) arrive in a
    // single batched reject request.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1387
/// Edit-triggered refreshes and diagnostic-triggered ("jump") refreshes use
/// independent throttles: the first request of each kind goes through
/// immediately, a second request of either kind is throttled without
/// affecting the other, and both are released once `THROTTLE_TIMEOUT`
/// elapses.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1488
1489#[gpui::test]
1490async fn test_rejections_flushing(cx: &mut TestAppContext) {
1491 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1492
1493 ep_store.update(cx, |ep_store, cx| {
1494 ep_store.reject_prediction(
1495 EditPredictionId("test-1".into()),
1496 EditPredictionRejectReason::Discarded,
1497 false,
1498 None,
1499 cx,
1500 );
1501 ep_store.reject_prediction(
1502 EditPredictionId("test-2".into()),
1503 EditPredictionRejectReason::Canceled,
1504 true,
1505 None,
1506 cx,
1507 );
1508 });
1509
1510 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1511 cx.run_until_parked();
1512
1513 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1514 respond_tx.send(()).unwrap();
1515
1516 // batched
1517 assert_eq!(reject_request.rejections.len(), 2);
1518 assert_eq!(
1519 reject_request.rejections[0],
1520 EditPredictionRejection {
1521 request_id: "test-1".to_string(),
1522 reason: EditPredictionRejectReason::Discarded,
1523 was_shown: false,
1524 model_version: None,
1525 }
1526 );
1527 assert_eq!(
1528 reject_request.rejections[1],
1529 EditPredictionRejection {
1530 request_id: "test-2".to_string(),
1531 reason: EditPredictionRejectReason::Canceled,
1532 was_shown: true,
1533 model_version: None,
1534 }
1535 );
1536
1537 // Reaching batch size limit sends without debounce
1538 ep_store.update(cx, |ep_store, cx| {
1539 for i in 0..70 {
1540 ep_store.reject_prediction(
1541 EditPredictionId(format!("batch-{}", i).into()),
1542 EditPredictionRejectReason::Discarded,
1543 false,
1544 None,
1545 cx,
1546 );
1547 }
1548 });
1549
1550 // First MAX/2 items are sent immediately
1551 cx.run_until_parked();
1552 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1553 respond_tx.send(()).unwrap();
1554
1555 assert_eq!(reject_request.rejections.len(), 50);
1556 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1557 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1558
1559 // Remaining items are debounced with the next batch
1560 cx.executor().advance_clock(Duration::from_secs(15));
1561 cx.run_until_parked();
1562
1563 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1564 respond_tx.send(()).unwrap();
1565
1566 assert_eq!(reject_request.rejections.len(), 20);
1567 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1568 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1569
1570 // Request failure
1571 ep_store.update(cx, |ep_store, cx| {
1572 ep_store.reject_prediction(
1573 EditPredictionId("retry-1".into()),
1574 EditPredictionRejectReason::Discarded,
1575 false,
1576 None,
1577 cx,
1578 );
1579 });
1580
1581 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1582 cx.run_until_parked();
1583
1584 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1585 assert_eq!(reject_request.rejections.len(), 1);
1586 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1587 // Simulate failure
1588 drop(_respond_tx);
1589
1590 // Add another rejection
1591 ep_store.update(cx, |ep_store, cx| {
1592 ep_store.reject_prediction(
1593 EditPredictionId("retry-2".into()),
1594 EditPredictionRejectReason::Discarded,
1595 false,
1596 None,
1597 cx,
1598 );
1599 });
1600
1601 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1602 cx.run_until_parked();
1603
1604 // Retry should include both the failed item and the new one
1605 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1606 respond_tx.send(()).unwrap();
1607
1608 assert_eq!(reject_request.rejections.len(), 2);
1609 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1610 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1611}
1612
1613// Skipped until we start including diagnostics in prompt
1614// #[gpui::test]
1615// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1616// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1617// let fs = FakeFs::new(cx.executor());
1618// fs.insert_tree(
1619// "/root",
1620// json!({
1621// "foo.md": "Hello!\nBye"
1622// }),
1623// )
1624// .await;
1625// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1626
1627// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1628// let diagnostic = lsp::Diagnostic {
1629// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1630// severity: Some(lsp::DiagnosticSeverity::ERROR),
1631// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1632// ..Default::default()
1633// };
1634
1635// project.update(cx, |project, cx| {
1636// project.lsp_store().update(cx, |lsp_store, cx| {
1637// // Create some diagnostics
1638// lsp_store
1639// .update_diagnostics(
1640// LanguageServerId(0),
1641// lsp::PublishDiagnosticsParams {
1642// uri: path_to_buffer_uri.clone(),
1643// diagnostics: vec![diagnostic],
1644// version: None,
1645// },
1646// None,
1647// language::DiagnosticSourceKind::Pushed,
1648// &[],
1649// cx,
1650// )
1651// .unwrap();
1652// });
1653// });
1654
1655// let buffer = project
1656// .update(cx, |project, cx| {
1657// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1658// project.open_buffer(path, cx)
1659// })
1660// .await
1661// .unwrap();
1662
1663// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1664// let position = snapshot.anchor_before(language::Point::new(0, 0));
1665
1666// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1667// ep_store.request_prediction(&project, &buffer, position, cx)
1668// });
1669
1670// let (request, _respond_tx) = req_rx.next().await.unwrap();
1671
1672// assert_eq!(request.diagnostic_groups.len(), 1);
1673// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1674// .unwrap();
1675// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1676// assert_eq!(
1677// value,
1678// json!({
1679// "entries": [{
1680// "range": {
1681// "start": 8,
1682// "end": 10
1683// },
1684// "diagnostic": {
1685// "source": null,
1686// "code": null,
1687// "code_description": null,
1688// "severity": 1,
1689// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1690// "markdown": null,
1691// "group_id": 0,
1692// "is_primary": true,
1693// "is_disk_based": false,
1694// "is_unnecessary": false,
1695// "source_kind": "Pushed",
1696// "data": null,
1697// "underline": true
1698// }
1699// }],
1700// "primary_ix": 0
1701// })
1702// );
1703// }
1704
1705// Generate a model response that would apply the given diff to the active file.
1706fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1707 let editable_range =
1708 zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1709 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1710 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1711
1712 PredictEditsV3Response {
1713 request_id: Uuid::new_v4().to_string(),
1714 editable_range,
1715 output: new_excerpt,
1716 model_version: None,
1717 }
1718}
1719
1720fn empty_response() -> PredictEditsV3Response {
1721 PredictEditsV3Response {
1722 request_id: Uuid::new_v4().to_string(),
1723 editable_range: 0..0,
1724 output: String::new(),
1725 model_version: None,
1726 }
1727}
1728
/// Renders the prompt text that would be sent to the model for `request`,
/// using the default zeta prompt format.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
}
1732
1733fn assert_no_predict_request_ready(
1734 requests: &mut mpsc::UnboundedReceiver<(
1735 PredictEditsV3Request,
1736 oneshot::Sender<PredictEditsV3Response>,
1737 )>,
1738) {
1739 if requests.next().now_or_never().flatten().is_some() {
1740 panic!("Unexpected prediction request while throttled.");
1741 }
1742}
1743
/// Receivers for the requests intercepted by the fake HTTP client created in
/// `init_test_with_fake_client`. Each intercepted request is paired with a
/// oneshot sender the test uses to deliver the server's response.
struct RequestChannels {
    // `/predict_edits/v3` requests with their response senders.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // `/predict_edits/reject` requests with their acknowledgement senders.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1751
/// Sets up test globals and creates an `EditPredictionStore` backed by a fake
/// HTTP client. Returns the store together with channels over which tests
/// receive intercepted predict/reject requests and supply their responses.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token requests are answered directly with a fixed token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            // Prediction bodies are zstd-compressed JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            // Hand the request to the test and wait for it to
                            // provide the response.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            // Reject bodies are plain (uncompressed) JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1821
/// Verifies `EditPrediction::interpolate`: as the user edits the buffer after
/// a prediction was made, the remaining predicted edits are re-derived
/// against the new buffer contents, shrinking as the user types the predicted
/// text and returning `None` once the buffer diverges from the prediction.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Prediction: replace offsets 2..5 with "REM" and delete offsets 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
        model_version: None,
    };

    cx.update(|cx| {
        // Unchanged buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the replacement into an insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted text shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Once the whole replacement has been typed, only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting a typed character brings the insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1941
/// End-to-end: runs predictions through `apply_edit_prediction` and asserts
/// the resulting buffer text for two cases — a variable rename applied at two
/// call sites, and an extension of a string literal.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1993
1994#[gpui::test]
1995async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1996 init_test(cx);
1997
1998 let buffer_content = "lorem\n";
1999 let completion_response = "lorem\nipsum\n";
2000
2001 assert_eq!(
2002 apply_edit_prediction(buffer_content, completion_response, cx).await,
2003 "lorem\nipsum\n"
2004 );
2005}
2006
/// When the buffer has no trailing newline but the model's output does, the
/// newline normalization must not insert a spurious trailing newline into the
/// predicted edits.
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "hello".
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let excerpt_length = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
        editable_range: 0..excerpt_length,
        model_version: None,
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2074
2075fn init_test(cx: &mut TestAppContext) {
2076 cx.update(|cx| {
2077 let settings_store = SettingsStore::test(cx);
2078 cx.set_global(settings_store);
2079 });
2080}
2081
2082async fn apply_edit_prediction(
2083 buffer_content: &str,
2084 completion_response: &str,
2085 cx: &mut TestAppContext,
2086) -> String {
2087 let fs = project::FakeFs::new(cx.executor());
2088 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2089 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2090 let (ep_store, response) = make_test_ep_store(&project, cx).await;
2091 *response.lock() = completion_response.to_string();
2092 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2093 buffer.update(cx, |buffer, cx| {
2094 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2095 });
2096 buffer.read_with(cx, |buffer, _| buffer.text())
2097}
2098
2099async fn run_edit_prediction(
2100 buffer: &Entity<Buffer>,
2101 project: &Entity<Project>,
2102 ep_store: &Entity<EditPredictionStore>,
2103 cx: &mut TestAppContext,
2104) -> EditPrediction {
2105 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2106 ep_store.update(cx, |ep_store, cx| {
2107 ep_store.register_buffer(buffer, &project, cx)
2108 });
2109 cx.background_executor.run_until_parked();
2110 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2111 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2112 });
2113 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2114}
2115
/// Creates an `EditPredictionStore` wired to a fake HTTP server whose
/// `/predict_edits/v3` endpoint replies with the contents of the returned
/// `Arc<Mutex<String>>` as the model output. Tests mutate that string to
/// control the response; request ids are generated sequentially.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (Method::POST, "/predict_edits/v3") => {
                        // Prediction bodies are zstd-compressed JSON.
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        // Reply with the canned output; editable range covers
                        // the whole excerpt.
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);

        // Set up license detection for each worktree so data-collection
        // checks have something to consult.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2198
2199fn to_completion_edits(
2200 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2201 buffer: &Entity<Buffer>,
2202 cx: &App,
2203) -> Vec<(Range<Anchor>, Arc<str>)> {
2204 let buffer = buffer.read(cx);
2205 iterator
2206 .into_iter()
2207 .map(|(range, text)| {
2208 (
2209 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2210 text,
2211 )
2212 })
2213 .collect()
2214}
2215
2216fn from_completion_edits(
2217 editor_edits: &[(Range<Anchor>, Arc<str>)],
2218 buffer: &Entity<Buffer>,
2219 cx: &App,
2220) -> Vec<(Range<usize>, Arc<str>)> {
2221 let buffer = buffer.read(cx);
2222 editor_edits
2223 .iter()
2224 .map(|(range, text)| {
2225 (
2226 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2227 text.clone(),
2228 )
2229 })
2230 .collect()
2231}
2232
/// With a client whose every HTTP request fails authentication (the fake
/// server returns 401) and no custom prediction URL configured, requesting a
/// prediction should return an error rather than succeeding silently.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request (including token requests) is rejected with 401.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2290
2291#[gpui::test]
2292fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2293 let buffer = cx.new(|cx| {
2294 Buffer::local(
2295 indoc! {"
2296 zero
2297 one
2298 two
2299 three
2300 four
2301 five
2302 six
2303 seven
2304 eight
2305 nine
2306 ten
2307 eleven
2308 twelve
2309 thirteen
2310 fourteen
2311 fifteen
2312 sixteen
2313 seventeen
2314 eighteen
2315 nineteen
2316 twenty
2317 twenty-one
2318 twenty-two
2319 twenty-three
2320 twenty-four
2321 "},
2322 cx,
2323 )
2324 });
2325
2326 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2327
2328 buffer.update(cx, |buffer, cx| {
2329 let point = Point::new(12, 0);
2330 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2331 let point = Point::new(8, 0);
2332 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2333 });
2334
2335 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2336
2337 let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2338
2339 assert_eq!(
2340 diff,
2341 indoc! {"
2342 @@ -6,10 +6,12 @@
2343 five
2344 six
2345 seven
2346 +FIRST INSERTION
2347 eight
2348 nine
2349 ten
2350 eleven
2351 +SECOND INSERTION
2352 twelve
2353 thirteen
2354 fourteen
2355 "}
2356 );
2357}
2358
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Simulate a remote collaborator's cursor at `row` by applying an
    // `UpdateSelections` operation stamped with a non-local replica id.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Push an ERROR diagnostic spanning columns 0..5 of each given row for
    // the file at `uri_path`.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 lines so diagnostics at rows 3, 25, and 50 are far apart.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator parked at row 5; diagnostics at rows 3, 25, 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor far from the collaborator — the diagnostic
    // near the collaborator (row 3) must be excluded.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor adjacent to the collaborator — a diagnostic
    // near both should still be reachable.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor far from the collaborator but on top of a
    // diagnostic — that diagnostic wins.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Cross-file case: when the same-file search range is exhausted, a
    // file holding a collaborator cursor should be skipped in favor of a
    // collaborator-free file.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // Search range covers the whole active file, so no same-file
    // diagnostic is "next" — forcing the cross-file path.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2574
#[gpui::test]
async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Buffer with two clearly separated regions:
    // Region A = lines 0-9, Region B = lines 20-29. Each line is
    // "line NN\n" (8 bytes), so the ten untouched lines in between keep
    // edits in one region from ever overlapping the other's range.
    let mut content = String::new();
    for i in 0..30 {
        content.push_str(&format!("line {i:02}\n"));
    }

    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content.clone()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Collect the (id, text) pairs delivered each time a prediction settles.
    let settled_events: Arc<Mutex<Vec<(EditPredictionId, String)>>> =
        Arc::new(Mutex::new(Vec::new()));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);

        let settled_events = settled_events.clone();
        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
            settled_events.lock().push((id, text));
        }));
    });

    // --- Phase 1: edit in region A and enqueue prediction A ---

    buffer.update(cx, |buffer, cx| {
        // Edit at the start of line 0.
        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    // Region A: first 10 lines of the buffer.
    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-a".into()),
            &project,
            &buffer,
            &snapshot_a,
            editable_region_a,
            cx,
        );
    });

    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---

    // Let the worker process the channel message before we start advancing.
    cx.run_until_parked();

    let mut region_a_edit_offset = 5;
    for _ in 0..3 {
        // Edit inside region A (not at the boundary) so `last_edit_at` is
        // updated before the worker's next wake.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
                None,
                cx,
            );
        });
        region_a_edit_offset += 1;
        cx.run_until_parked();

        // Half the quiescence window is never enough for A to settle.
        cx.executor()
            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
        cx.run_until_parked();
        assert!(
            settled_events.lock().is_empty(),
            "no settled events should fire while region A is still being edited"
        );
    }

    // Still nothing settled.
    assert!(settled_events.lock().is_empty());

    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
    // Advance a small amount so B's quiescence window starts later than A's,
    // but not so much that A settles (A's last edit was at the start of
    // iteration 3, and it needs a full Q to settle).
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();
    assert!(settled_events.lock().is_empty());

    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));

    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
    });
    cx.run_until_parked();

    // Region B: lines 20-24, starting at the edit we just made.
    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-b".into()),
            &project,
            &buffer,
            &snapshot_b2,
            editable_region_b,
            cx,
        );
    });

    cx.run_until_parked();
    assert!(
        settled_events.lock().is_empty(),
        "neither prediction should have settled yet"
    );

    // --- Phase 4: let enough time pass for region A to settle ---
    // A's last edit was at T_a (during the last loop iteration). The worker is
    // sleeping until T_a + Q. We advance just enough to reach that wake time
    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
    // enqueued only Q/4 ago and stays pending.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            1,
            "only prediction A should have settled, got: {events:?}"
        );
        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
    }

    // --- Phase 5: let more time pass for region B to settle ---
    // B's last edit was Q/4 before A settled. The worker rescheduled to
    // B's last_edit_at + Q, which is 3Q/4 from now.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            2,
            "both predictions should have settled, got: {events:?}"
        );
        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
    }
}
2745
// Runs once at test-binary startup (before any test executes) so log
// output is available from every test in this file.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}