1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{
33 BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
34 EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
35};
36
// Exercises the store's notion of the "current" prediction across buffers:
// a prediction editing the active buffer surfaces as `Local`; after that one
// is rejected, a prediction targeting a diagnostic in another (unopened) file
// surfaces as `Jump` from buffer 1 and as `Local` once buffer 2 is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Mark 1.txt as the active file so predictions are evaluated
            // relative to it.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The fake model rewrites line 1 of the active buffer in place.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        // Edit lands in the buffer being queried -> shown as a local edit.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request
    // without any cursor movement.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        // Queried from buffer 1, but the edit is in 2.txt -> presented as a
        // jump target pointing at that file.
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        // Same prediction queried from its own buffer -> local again.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
172
// Basic request/response round trip: one prediction request yields one edit
// whose anchor range and replacement text are derived from the model's
// unified diff (the full-line rewrite is interpolated down to an insertion).
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at row 1, col 3 — i.e. just after "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    // request.excerpt_path.as_ref(),
    // Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    // request.cursor_point,
    // Point {
    // line: Line(1),
    // column: 3
    // }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff rewrites line 1; only the suffix " are you?" actually differs,
    // so the resulting edit is a single insertion at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
239
// Verifies that recent buffer edits are included in the prompt sent to the
// model as a unified diff (the edit-history "events" section), and that the
// response still resolves into edits as usual.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registering the buffer starts edit-history tracking for it.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Fill the empty middle line with "How"; this edit should show up in the
    // request prompt below.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
312
// Verifies time-based splitting of the edit history: two bursts of edits on
// the same line, separated by a pause longer than the grouping threshold,
// are reported as two separate `BufferChange` events rather than one.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 0: everything before the pause (the "How" insertion).
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 1: both post-pause edits, coalesced into a single diff.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
390
// NOTE(review): despite the name, this test exercises line-distance based
// coalescing of ordinary (manual) edits in the edit history — edits within a
// few lines of an existing event's range merge into it, while farther edits
// start a new event. Predicted-vs-manual separation is covered by
// `test_predicted_flag_coalescing`; consider renaming this test to match.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now; `render_events` joins them with "\n---\n".
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
569
570fn render_events(events: &[StoredEvent]) -> String {
571 events
572 .iter()
573 .map(|e| {
574 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
575 diff.as_str()
576 })
577 .collect::<Vec<_>>()
578 .join("\n---\n")
579}
580
581fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
582 events
583 .iter()
584 .map(|e| {
585 let zeta_prompt::Event::BufferChange {
586 diff, predicted, ..
587 } = e.event.as_ref();
588 let prefix = if *predicted { "predicted" } else { "manual" };
589 format!("{}\n{}", prefix, diff)
590 })
591 .collect()
592}
593
// Verifies how the `predicted` flag interacts with edit-history coalescing:
// nearby edits of the same kind merge into one event, a manual edit never
// merges into a preceding predicted event, and once the "last event" is
// finalized by a distant edit, mixed runs collapse with `predicted` meaning
// "all constituent edits were predicted".
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        // `true` marks the just-made change as coming from an accepted prediction.
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
852
// A model response containing no edits should produce no current prediction,
// and should be reported back to the server as a rejection with reason
// `Empty` (and `was_shown: false`, since it was never displayed).
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Empty diff -> the model predicted no changes.
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
908
// If the user types the predicted text before the response arrives, the
// prediction interpolates down to nothing: no current prediction, and a
// rejection report with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the exact change the model is about to suggest *before* the
    // response lands, so interpolation leaves no remaining edits.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
969
/// Canonical fake model response shared by several tests: a unified diff that
/// rewrites line 1 of `root/foo.md` from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
979
// When a newer prediction arrives that is at least as good as the current
// one, it replaces it; the replaced prediction is reported as rejected with
// reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Identical quality to the first -> the newer one wins.
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1062
// When a newer prediction is *worse* than the current one, the current one is
// kept and the newer response is reported as rejected with reason
// `CurrentPreferred`. Complements `test_replace_current` above.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1157
// If a later request completes before an earlier in-flight one, the earlier
// request is cancelled: its response (arriving afterwards) must not replace
// the current prediction, and it is reported as rejected with `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the (already-cancelled) first request respond.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1249
// Verifies request-cancellation order: with two predictions already in
// flight, starting a third cancels the second (newest pending) request. The
// first may still complete normally, and the third replaces it once it lands.
// Both the cancellation and the replacement must be reported as rejections.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Same cursor position (row 1, column 3) is used for all refreshes below.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled second request; its result must be discarded.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both the cancelled second request and the replaced first prediction
    // arrive in a single batched rejection request, in that order.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1387
// Edit-triggered and diagnostic-("jump")-triggered refreshes each carry their
// own throttle: a recent edit request must not delay a jump request and vice
// versa, while a second request of the *same* kind is throttled until
// `THROTTLE_TIMEOUT` elapses.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Error diagnostic in the *other* file; publishing it below triggers a
    // diagnostics-based ("jump") prediction refresh.
    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1488
// Two identical refresh calls issued within the same synchronous frame must
// collapse into a single prediction request on the wire.
#[gpui::test]
async fn test_same_frame_duplicate_requests_deduplicated(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Enqueue two refresh calls in the same synchronous frame (no yielding).
    // Both `cx.spawn` tasks are created before either executes, so they both
    // capture the same `proceed_count_at_enqueue`. Only the first task should
    // pass the deduplication gate; the second should be skipped.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Let both spawned tasks run to completion (including any throttle waits).
    cx.run_until_parked();

    // Exactly one prediction request should have been sent.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(&request, SIMPLE_DIFF))
        .unwrap();
    cx.run_until_parked();

    // No second request should be pending.
    assert_no_predict_request_ready(&mut requests.predict);
}
1534
// Exercises how rejection reports are flushed to the server:
// - multiple rejections within the debounce window are batched into one request,
// - reaching the batch-size limit flushes half the queue immediately,
// - a failed request's payload is retried together with the next batch.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false,
            model_version: None,
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true,
            model_version: None,
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                None,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    // Dropping the responder makes the fake server's oneshot resolve with an
    // error, which the store treats as a failed flush.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1658
1659// Skipped until we start including diagnostics in prompt
1660// #[gpui::test]
1661// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1662// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1663// let fs = FakeFs::new(cx.executor());
1664// fs.insert_tree(
1665// "/root",
1666// json!({
1667// "foo.md": "Hello!\nBye"
1668// }),
1669// )
1670// .await;
1671// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1672
1673// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1674// let diagnostic = lsp::Diagnostic {
1675// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1676// severity: Some(lsp::DiagnosticSeverity::ERROR),
1677// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1678// ..Default::default()
1679// };
1680
1681// project.update(cx, |project, cx| {
1682// project.lsp_store().update(cx, |lsp_store, cx| {
1683// // Create some diagnostics
1684// lsp_store
1685// .update_diagnostics(
1686// LanguageServerId(0),
1687// lsp::PublishDiagnosticsParams {
1688// uri: path_to_buffer_uri.clone(),
1689// diagnostics: vec![diagnostic],
1690// version: None,
1691// },
1692// None,
1693// language::DiagnosticSourceKind::Pushed,
1694// &[],
1695// cx,
1696// )
1697// .unwrap();
1698// });
1699// });
1700
1701// let buffer = project
1702// .update(cx, |project, cx| {
1703// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1704// project.open_buffer(path, cx)
1705// })
1706// .await
1707// .unwrap();
1708
1709// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1710// let position = snapshot.anchor_before(language::Point::new(0, 0));
1711
1712// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1713// ep_store.request_prediction(&project, &buffer, position, cx)
1714// });
1715
1716// let (request, _respond_tx) = req_rx.next().await.unwrap();
1717
1718// assert_eq!(request.diagnostic_groups.len(), 1);
1719// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1720// .unwrap();
1721// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1722// assert_eq!(
1723// value,
1724// json!({
1725// "entries": [{
1726// "range": {
1727// "start": 8,
1728// "end": 10
1729// },
1730// "diagnostic": {
1731// "source": null,
1732// "code": null,
1733// "code_description": null,
1734// "severity": 1,
1735// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1736// "markdown": null,
1737// "group_id": 0,
1738// "is_primary": true,
1739// "is_disk_based": false,
1740// "is_unnecessary": false,
1741// "source_kind": "Pushed",
1742// "data": null,
1743// "underline": true
1744// }
1745// }],
1746// "primary_ix": 0
1747// })
1748// );
1749// }
1750
1751// Generate a model response that would apply the given diff to the active file.
1752fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1753 let editable_range =
1754 zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1755 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1756 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1757
1758 PredictEditsV3Response {
1759 request_id: Uuid::new_v4().to_string(),
1760 editable_range,
1761 output: new_excerpt,
1762 model_version: None,
1763 }
1764}
1765
1766fn empty_response() -> PredictEditsV3Response {
1767 PredictEditsV3Response {
1768 request_id: Uuid::new_v4().to_string(),
1769 editable_range: 0..0,
1770 output: String::new(),
1771 model_version: None,
1772 }
1773}
1774
1775fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1776 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1777}
1778
1779fn assert_no_predict_request_ready(
1780 requests: &mut mpsc::UnboundedReceiver<(
1781 PredictEditsV3Request,
1782 oneshot::Sender<PredictEditsV3Response>,
1783 )>,
1784) {
1785 if requests.next().now_or_never().flatten().is_some() {
1786 panic!("Unexpected prediction request while throttled.");
1787 }
1788}
1789
/// Receiving ends of the fake HTTP endpoints created by
/// `init_test_with_fake_client`. Each item pairs a decoded request body with
/// a oneshot sender the test uses to supply that request's response.
struct RequestChannels {
    // Requests hitting `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests hitting `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1797
/// Set up an `EditPredictionStore` backed by a fake HTTP client.
///
/// The fake server forwards each `/predict_edits/v3` and
/// `/predict_edits/reject` request over the returned `RequestChannels`,
/// letting the test inspect the request and control exactly when (and
/// whether) a response is delivered. `/client/llm_tokens` is answered
/// inline with a dummy token.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            // Request bodies are zstd-compressed JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            // Hand the request to the test and await its response;
                            // dropping `res_tx` simulates a request failure.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            // Reject bodies are plain (uncompressed) JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(user_store.clone(), client.clone(), cx);
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1866
// A prediction's anchored edits should "interpolate" as the user types:
// typed text that matches a pending edit shrinks that edit, unrelated edits
// shift the ranges, and typing that contradicts the prediction invalidates it
// (interpolate returns `None`).
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Prediction: replace "rem" (2..5) with "REM" and delete "um" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
        model_version: None,
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the to-be-replaced text turns the replacement into an insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted replacement consumes it char by char.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Completing the replacement leaves only the pending deletion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting a typed character re-opens part of the replacement.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion manually removes that edit too.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that contradicts the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1987
// End-to-end check that applying a model response produces exactly the
// expected buffer text for small, targeted rewrites (identifier rename,
// string extension).
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Renaming `word` to `word_1` on one line only.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extending a string literal and adding a trailing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
2039
2040#[gpui::test]
2041async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2042 init_test(cx);
2043
2044 let buffer_content = "lorem\n";
2045 let completion_response = "lorem\nipsum\n";
2046
2047 assert_eq!(
2048 apply_edit_prediction(buffer_content, completion_response, cx).await,
2049 "lorem\nipsum\n"
2050 );
2051}
2052
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "hello" (row 0, column 5).
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let excerpt_length = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
        editable_range: 0..excerpt_length,
        model_version: None,
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve the anchored edits back to offsets for a stable comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2120
2121fn init_test(cx: &mut TestAppContext) {
2122 cx.update(|cx| {
2123 let settings_store = SettingsStore::test(cx);
2124 cx.set_global(settings_store);
2125 });
2126}
2127
2128async fn apply_edit_prediction(
2129 buffer_content: &str,
2130 completion_response: &str,
2131 cx: &mut TestAppContext,
2132) -> String {
2133 let fs = project::FakeFs::new(cx.executor());
2134 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2135 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2136 let (ep_store, response) = make_test_ep_store(&project, cx).await;
2137 *response.lock() = completion_response.to_string();
2138 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2139 buffer.update(cx, |buffer, cx| {
2140 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2141 });
2142 buffer.read_with(cx, |buffer, _| buffer.text())
2143}
2144
2145async fn run_edit_prediction(
2146 buffer: &Entity<Buffer>,
2147 project: &Entity<Project>,
2148 ep_store: &Entity<EditPredictionStore>,
2149 cx: &mut TestAppContext,
2150) -> EditPrediction {
2151 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2152 ep_store.update(cx, |ep_store, cx| {
2153 ep_store.register_buffer(buffer, &project, cx)
2154 });
2155 cx.background_executor.run_until_parked();
2156 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2157 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2158 });
2159 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2160}
2161
/// Build an `EditPredictionStore` whose fake HTTP server always succeeds.
///
/// Returns the store plus a shared string slot; whatever the slot holds is
/// returned as the model output of every `/predict_edits/v3` call. Request
/// ids are sequential (`request-1`, `request-2`, ...). Unlike
/// `init_test_with_fake_client`, the test cannot delay or fail responses.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (Method::POST, "/predict_edits/v3") => {
                        // Request bodies are zstd-compressed JSON.
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    // Treat the whole excerpt as editable.
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);

        // Install a license watcher per worktree so data-collection checks
        // have something to consult.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2245
2246fn to_completion_edits(
2247 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2248 buffer: &Entity<Buffer>,
2249 cx: &App,
2250) -> Vec<(Range<Anchor>, Arc<str>)> {
2251 let buffer = buffer.read(cx);
2252 iterator
2253 .into_iter()
2254 .map(|(range, text)| {
2255 (
2256 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2257 text,
2258 )
2259 })
2260 .collect()
2261}
2262
2263fn from_completion_edits(
2264 editor_edits: &[(Range<Anchor>, Arc<str>)],
2265 buffer: &Entity<Buffer>,
2266 cx: &App,
2267) -> Vec<(Range<usize>, Arc<str>)> {
2268 let buffer = buffer.read(cx);
2269 editor_edits
2270 .iter()
2271 .map(|(range, text)| {
2272 (
2273 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2274 text.clone(),
2275 )
2276 })
2277 .collect()
2278}
2279
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    // When the client is unauthenticated (every HTTP request is answered with
    // 401) and no custom prediction URL is configured, requesting an edit
    // prediction must resolve to an error rather than hang or succeed.
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Fake HTTP backend that rejects everything, simulating a client without
    // valid credentials.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Place the cursor inside the empty body of `main` (line 1, column 4).
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    // Let registration-related background work finish before requesting.
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2338
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    // Verifies that `compute_diff_between_snapshots` merges two nearby
    // insertions into a single unified-diff hunk with shared context lines.
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    buffer.update(cx, |buffer, cx| {
        // Apply the later edit (row 12) first so the second edit's point,
        // expressed in original coordinates, isn't shifted by an earlier
        // insertion above it.
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Both insertions are close enough to fall into one hunk: 10 old lines
    // become 12 new lines starting at line 6.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2406
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // `next_diagnostic_location` should avoid jumping to diagnostics that sit
    // near a remote collaborator's cursor — unless the local cursor is also in
    // that area — both within a buffer and when scanning other files.

    // Simulates a remote collaborator by applying an `UpdateSelections`
    // operation from a different replica id, placing their cursor at the
    // start of `row`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Pushes one ERROR diagnostic per row (columns 0..5) for the file at
    // `uri_path` through the project's LSP store.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 lines so diagnostics at rows 3, 25, and 50 are well separated.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator sits at row 5; the nearby diagnostic at row 3 should be
    // excluded when the local cursor is far away.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, right next to the collaborator — the
    // row-3 diagnostic becomes eligible again.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50 — the diagnostic right there is chosen.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4: cross-file search. Both other files have a diagnostic at row 0,
    // but collab_file.txt has a collaborator cursor and should be skipped.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // Search range covers the whole active file so no same-file diagnostic
    // remains outside it, forcing the search into other files.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2622
#[gpui::test]
async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
    // Exercises the "settled prediction" worker: a prediction enqueued for a
    // buffer region only settles once that region has seen no edits for a full
    // EDIT_PREDICTION_SETTLED_QUIESCENCE window, and predictions for distinct
    // regions settle independently.
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Buffer with two clearly separated regions:
    // Region A = lines 0-9 (offsets 0..50)
    // Region B = lines 20-29 (offsets 105..155)
    // A big gap in between so edits in one region never overlap the other.
    let mut content = String::new();
    for i in 0..30 {
        content.push_str(&format!("line {i:02}\n"));
    }

    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content.clone()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Record every settled event as (prediction id, settled text) so the
    // assertions below can check order and count.
    type SettledEventRecord = (EditPredictionId, String);
    let settled_events: Arc<Mutex<Vec<SettledEventRecord>>> = Arc::new(Mutex::new(Vec::new()));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);

        let settled_events = settled_events.clone();
        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
            settled_events.lock().push((id, text));
        }));
    });

    // --- Phase 1: edit in region A and enqueue prediction A ---

    buffer.update(cx, |buffer, cx| {
        // Edit at the start of line 0.
        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    // Region A: first 10 lines of the buffer.
    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-a".into()),
            &project,
            &buffer,
            &snapshot_a,
            editable_region_a.clone(),
            None,
            cx,
        );
    });

    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---

    // Let the worker process the channel message before we start advancing.
    cx.run_until_parked();

    let mut region_a_edit_offset = 5;
    for _ in 0..3 {
        // Edit inside region A (not at the boundary) so `last_edit_at` is
        // updated before the worker's next wake.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
                None,
                cx,
            );
        });
        region_a_edit_offset += 1;
        cx.run_until_parked();

        // Only half a quiescence window passes between edits, so region A
        // never stays quiet long enough to settle.
        cx.executor()
            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
        cx.run_until_parked();
        assert!(
            settled_events.lock().is_empty(),
            "no settled events should fire while region A is still being edited"
        );
    }

    // Still nothing settled.
    assert!(settled_events.lock().is_empty());

    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
    // Advance a small amount so B's quiescence window starts later than A's,
    // but not so much that A settles (A's last edit was at the start of
    // iteration 3, and it needs a full Q to settle).
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();
    assert!(settled_events.lock().is_empty());

    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));

    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-b".into()),
            &project,
            &buffer,
            &snapshot_b2,
            editable_region_b.clone(),
            None,
            cx,
        );
    });

    cx.run_until_parked();
    assert!(
        settled_events.lock().is_empty(),
        "neither prediction should have settled yet"
    );

    // --- Phase 4: let enough time pass for region A to settle ---
    // A's last edit was at T_a (during the last loop iteration). The worker is
    // sleeping until T_a + Q. We advance just enough to reach that wake time
    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
    // enqueued only Q/4 ago and stays pending.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            1,
            "prediction and capture_sample for A should have settled, got: {events:?}"
        );
        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
    }

    // --- Phase 5: let more time pass for region B to settle ---
    // B's last edit was Q/4 before A settled. The worker rescheduled to
    // B's last_edit_at + Q, which is 3Q/4 from now.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            2,
            "both prediction and capture_sample settled events should be emitted for each request, got: {events:?}"
        );
        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
    }
}
2797
// Runs once at binary load time (before any test executes) so log output is
// initialized for the whole test run.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}