1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{
21 Anchor, Buffer, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSeverity,
22 Operation, Point, Selection, SelectionGoal,
23};
24use lsp::LanguageServerId;
25use parking_lot::Mutex;
26use pretty_assertions::{assert_eq, assert_matches};
27use project::{FakeFs, Project};
28use serde_json::json;
29use settings::SettingsStore;
30use std::{path::Path, sync::Arc, time::Duration};
31use util::{
32 path,
33 test::{TextRangeMarker, marked_text_ranges_by},
34};
35use uuid::Uuid;
36use zeta_prompt::ZetaPromptInput;
37
38use crate::{
39 BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
40 EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
41};
42
// Checks what `prediction_at` reports as the "current" prediction in three steps:
// 1. A prediction requested from `1.txt` is `Local` to that buffer.
// 2. After it is rejected and a diagnostic is published for `2.txt`, the test
//    expects a new prediction request. Its result is surfaced to `1.txt` as a
//    `Jump` targeting `2.txt`.
// 3. Once `2.txt` is opened as a buffer, the same prediction is `Local` there.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open `1.txt` and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at row 1, column 3: the end of "How".
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the local prediction so the next one can take its place.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publish the diagnostic for `2.txt`. The test then expects this to produce
    // a new prediction request.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from `1.txt`, the prediction is a jump to the other file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from `2.txt` itself, the same prediction is local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
178
// Checks that a direct `request_prediction` call turns the model's diff into
// a single edit. The edit inserts " are you?" at the cursor (row 1, col 3).
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The line replacement is expected to be reduced to a single insertion
    // after the existing "How".
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
245
// Checks that an edit made to a registered buffer shows up, as a unified diff,
// in the prompt sent with the next prediction request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register before editing so the edit below is recorded in edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Offset 7 is the empty second line (just after "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The recorded edit must appear as a diff inside the prompt text.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
318
// Checks that `edit_history_for_project` splits edits on the same line into
// two events when a pause longer than `LAST_CHANGE_GROUPING_TIME` separates
// them.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause of twice the grouping threshold (`LAST_CHANGE_GROUPING_TIME`).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    // Offset 19 is the end of "How are you?".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // The second event's diff is relative to the text after the first burst.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
396
// Exercises line-proximity coalescing of edit-history events. Edits within
// 8 rows of either end of the previous event's range are merged into it; an
// edit farther away starts a new event.
//
// NOTE(review): despite the test's name, every edit here is manual (none is
// reported as predicted). Confirm whether a predicted-edit case was intended.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // `render_events` separates events with "\n---\n". That yields the blank
    // line before `---` below.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
575
576fn render_events(events: &[StoredEvent]) -> String {
577 events
578 .iter()
579 .map(|e| {
580 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
581 diff.as_str()
582 })
583 .collect::<Vec<_>>()
584 .join("\n---\n")
585}
586
587fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
588 events
589 .iter()
590 .map(|e| {
591 let zeta_prompt::Event::BufferChange {
592 diff, predicted, ..
593 } = e.event.as_ref();
594 let prefix = if *predicted { "predicted" } else { "manual" };
595 format!("{}\n{}", prefix, diff)
596 })
597 .collect()
598}
599
// Checks how the `predicted` flag interacts with coalescing of nearby edits.
// Nearby edits merge only while they share the same origin (manual vs.
// predicted). Once a later edit lands elsewhere, the earlier nearby events
// merge regardless of origin. The merged event is labelled predicted only if
// every constituent edit was predicted.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // It also shows that a predicted edit following a manual edit is not merged
    // into the manual event, even when it is nearby.
    // The edit is reported with `true` as the third argument, which the
    // expectations below treat as an accepted prediction.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5: A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
858
// Checks that an empty model response produces no current prediction. The
// request is then reported in a rejection batch with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff, remembering the request id for the
    // rejection check.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
914
// Checks the case where the user edits the buffer while a request is in
// flight and the edit already matches what the model will predict. The
// prediction then has nothing left to apply. No current prediction is kept,
// and the request is rejected with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Before the response arrives, the user types the predicted text
    // ("How" -> "How are you?") themselves.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
975
// Canned model response for `root/foo.md`. It replaces the line "How" with
// "How are you?". Several tests in this file use it as a generic non-empty
// prediction.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
985
// Checks that when a second request returns the same prediction, it becomes
// the current prediction. The first prediction is then reported as rejected
// with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1068
// Checks that when a second request returns a prediction the test considers
// worse ("How are" instead of "How are you?"), the first prediction stays
// current. The second is then reported as rejected with reason
// `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1163
// Checks that when two requests are pending and the newer one responds first,
// the newer prediction becomes current. When the older response arrives later,
// it is ignored and reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // `request` here is the second (newer) request.
    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // The older request's response arrives last.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1255
/// Starts three prediction requests for the same cursor position and checks
/// which of them end up cancelled, current, or reported as rejected.
///
/// With two requests still pending, starting a third marks the second one as
/// cancelled (`cancelled_predictions == [1]`). The responses are then resolved
/// out of order, and the test asserts that:
/// - the first response becomes the current prediction,
/// - the (cancelled) second response does not replace it,
/// - the third response replaces the first,
/// - a single rejection report lists the second as `Canceled` and the first
///   as `Replaced`, in that order.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Row 1 is "How"; column 3 is the end of that line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (presumably `1` is the 0-based
        // index of the second request — confirm against the store).
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both the cancelled second prediction and the replaced first prediction
    // are reported as rejected, in a single batch.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1393
/// Checks that edit-triggered and diagnostics-triggered ("jump") prediction
/// requests are throttled independently of each other.
///
/// The first request of each kind is sent immediately, even though a request
/// of the other kind was just made. A second request of each kind is held back
/// until `EditPredictionStore::THROTTLE_TIMEOUT` has elapsed, after which both
/// go through.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Register the project and buffer with the store explicitly (presumably so
    // it observes project events such as the diagnostics update below — confirm).
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request, triggered by publishing a diagnostic for bar.md (not
    // the open foo.md buffer) - no prior jump, so not throttled (independent
    // from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1494
/// Two `refresh_prediction_from_buffer` calls made within the same synchronous
/// `update` must produce exactly one prediction request.
#[gpui::test]
async fn test_same_frame_duplicate_requests_deduplicated(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Enqueue two refresh calls in the same synchronous frame (no yielding).
    // Both `cx.spawn` tasks are created before either executes, so they both
    // capture the same `proceed_count_at_enqueue`. Only the first task should
    // pass the deduplication gate; the second should be skipped.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Let both spawned tasks run to completion (including any throttle waits).
    cx.run_until_parked();

    // Exactly one prediction request should have been sent.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(&request, SIMPLE_DIFF))
        .unwrap();
    cx.run_until_parked();

    // No second request should be pending.
    assert_no_predict_request_ready(&mut requests.predict);
}
1540
1541#[gpui::test]
1542async fn test_rejections_flushing(cx: &mut TestAppContext) {
1543 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1544
1545 ep_store.update(cx, |ep_store, cx| {
1546 ep_store.reject_prediction(
1547 EditPredictionId("test-1".into()),
1548 EditPredictionRejectReason::Discarded,
1549 false,
1550 None,
1551 cx,
1552 );
1553 ep_store.reject_prediction(
1554 EditPredictionId("test-2".into()),
1555 EditPredictionRejectReason::Canceled,
1556 true,
1557 None,
1558 cx,
1559 );
1560 });
1561
1562 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1563 cx.run_until_parked();
1564
1565 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1566 respond_tx.send(()).unwrap();
1567
1568 // batched
1569 assert_eq!(reject_request.rejections.len(), 2);
1570 assert_eq!(
1571 reject_request.rejections[0],
1572 EditPredictionRejection {
1573 request_id: "test-1".to_string(),
1574 reason: EditPredictionRejectReason::Discarded,
1575 was_shown: false,
1576 model_version: None,
1577 }
1578 );
1579 assert_eq!(
1580 reject_request.rejections[1],
1581 EditPredictionRejection {
1582 request_id: "test-2".to_string(),
1583 reason: EditPredictionRejectReason::Canceled,
1584 was_shown: true,
1585 model_version: None,
1586 }
1587 );
1588
1589 // Reaching batch size limit sends without debounce
1590 ep_store.update(cx, |ep_store, cx| {
1591 for i in 0..70 {
1592 ep_store.reject_prediction(
1593 EditPredictionId(format!("batch-{}", i).into()),
1594 EditPredictionRejectReason::Discarded,
1595 false,
1596 None,
1597 cx,
1598 );
1599 }
1600 });
1601
1602 // First MAX/2 items are sent immediately
1603 cx.run_until_parked();
1604 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1605 respond_tx.send(()).unwrap();
1606
1607 assert_eq!(reject_request.rejections.len(), 50);
1608 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1609 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1610
1611 // Remaining items are debounced with the next batch
1612 cx.executor().advance_clock(Duration::from_secs(15));
1613 cx.run_until_parked();
1614
1615 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1616 respond_tx.send(()).unwrap();
1617
1618 assert_eq!(reject_request.rejections.len(), 20);
1619 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1620 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1621
1622 // Request failure
1623 ep_store.update(cx, |ep_store, cx| {
1624 ep_store.reject_prediction(
1625 EditPredictionId("retry-1".into()),
1626 EditPredictionRejectReason::Discarded,
1627 false,
1628 None,
1629 cx,
1630 );
1631 });
1632
1633 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1634 cx.run_until_parked();
1635
1636 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1637 assert_eq!(reject_request.rejections.len(), 1);
1638 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1639 // Simulate failure
1640 drop(_respond_tx);
1641
1642 // Add another rejection
1643 ep_store.update(cx, |ep_store, cx| {
1644 ep_store.reject_prediction(
1645 EditPredictionId("retry-2".into()),
1646 EditPredictionRejectReason::Discarded,
1647 false,
1648 None,
1649 cx,
1650 );
1651 });
1652
1653 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1654 cx.run_until_parked();
1655
1656 // Retry should include both the failed item and the new one
1657 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1658 respond_tx.send(()).unwrap();
1659
1660 assert_eq!(reject_request.rejections.len(), 2);
1661 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1662 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1663}
1664
/// Checks `zeta::active_buffer_diagnostics`, which collects the diagnostics of
/// a buffer snapshot that fall inside a search range.
///
/// 1. With three diagnostics (warning, error, hint) and a search range that
///    covers only `fn beta`, exactly the error inside that range is returned,
///    with the entire buffer text as its snippet.
/// 2. With diagnostics on rows 0, 2 and 4 and the range `(2, 0)..(4, 0)`, the
///    diagnostics on rows 2 and 4 are returned and the one on row 0 is not.
///
/// The expected severities are numeric: the ERROR / WARNING / INFORMATION
/// severities assigned below come back as `Some(1)` / `Some(2)` / `Some(3)`.
#[gpui::test]
fn test_active_buffer_diagnostics_fetching(cx: &mut TestAppContext) {
    // `«…»` marks diagnostic ranges; `[…]` marks the search range.
    let diagnostic_marker: TextRangeMarker = ('«', '»').into();
    let search_range_marker: TextRangeMarker = ('[', ']').into();

    // NOTE(review): `ˇ` is not one of the markers passed below, so it is
    // presumably kept verbatim in `text` — confirm that is intentional.
    let (text, mut ranges) = marked_text_ranges_by(
        indoc! {r#"
            fn alpha() {
                let «first_value» = 1;
            }

            [fn beta() {
                let «second_value» = 2;
                let third_value = second_value + missing_symbol;
            }ˇ]

            fn gamma() {
                let «fourth_value» = missing_other_symbol;
            }
        "#},
        vec![diagnostic_marker.clone(), search_range_marker.clone()],
    );

    let diagnostic_ranges = ranges.remove(&diagnostic_marker).unwrap_or_default();
    let search_ranges = ranges.remove(&search_range_marker).unwrap_or_default();

    let buffer = cx.new(|cx| Buffer::local(&text, cx));

    // Attach one diagnostic per `«…»` range, in order: warning, error, hint.
    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let diagnostics = DiagnosticSet::new(
            diagnostic_ranges
                .iter()
                .enumerate()
                .map(|(index, range)| DiagnosticEntry {
                    range: snapshot.offset_to_point_utf16(range.start)
                        ..snapshot.offset_to_point_utf16(range.end),
                    diagnostic: Diagnostic {
                        severity: match index {
                            0 => DiagnosticSeverity::WARNING,
                            1 => DiagnosticSeverity::ERROR,
                            _ => DiagnosticSeverity::HINT,
                        },
                        message: match index {
                            0 => "first warning".to_string(),
                            1 => "second error".to_string(),
                            _ => "third hint".to_string(),
                        },
                        group_id: index + 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                }),
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let search_range = snapshot.offset_to_point(search_ranges[0].start)
        ..snapshot.offset_to_point(search_ranges[0].end);

    let active_buffer_diagnostics = zeta::active_buffer_diagnostics(&snapshot, search_range, 100);

    // Only `second_value` (row 5, offsets 61..73 of `text`) lies in the range.
    assert_eq!(
        active_buffer_diagnostics,
        vec![zeta_prompt::ActiveBufferDiagnostic {
            severity: Some(1),
            message: "second error".to_string(),
            snippet: text,
            snippet_buffer_row_range: 5..5,
            diagnostic_range_in_snippet: 61..73,
        }]
    );

    // Rows are zero-based: "row two" covers "three" (offsets 8..13) and
    // "row four" covers "five" (offsets 19..23).
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                one
                two
                three
                four
                five
            "},
            cx,
        )
    });

    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let diagnostics = DiagnosticSet::new(
            vec![
                DiagnosticEntry {
                    range: text::PointUtf16::new(0, 0)..text::PointUtf16::new(0, 3),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::ERROR,
                        message: "row zero".to_string(),
                        group_id: 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(2, 0)..text::PointUtf16::new(2, 5),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::WARNING,
                        message: "row two".to_string(),
                        group_id: 2,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(4, 0)..text::PointUtf16::new(4, 4),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::INFORMATION,
                        message: "row four".to_string(),
                        group_id: 3,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
            ],
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    // The range ends at (4, 0), yet the diagnostic starting on row 4 is still
    // expected in the result; the row-0 diagnostic is excluded.
    let active_buffer_diagnostics =
        zeta::active_buffer_diagnostics(&snapshot, Point::new(2, 0)..Point::new(4, 0), 100);

    assert_eq!(
        active_buffer_diagnostics
            .iter()
            .map(|diagnostic| (
                diagnostic.severity,
                diagnostic.message.clone(),
                diagnostic.snippet.clone(),
                diagnostic.snippet_buffer_row_range.clone(),
                diagnostic.diagnostic_range_in_snippet.clone(),
            ))
            .collect::<Vec<_>>(),
        vec![
            (
                Some(2),
                "row two".to_string(),
                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                2..2,
                8..13,
            ),
            (
                Some(3),
                "row four".to_string(),
                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                4..4,
                19..23,
            ),
        ]
    );
}
1831
1832// Generate a model response that would apply the given diff to the active file.
1833fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1834 let editable_range =
1835 zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1836 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1837 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1838
1839 PredictEditsV3Response {
1840 request_id: Uuid::new_v4().to_string(),
1841 editable_range,
1842 output: new_excerpt,
1843 model_version: None,
1844 }
1845}
1846
1847fn empty_response() -> PredictEditsV3Response {
1848 PredictEditsV3Response {
1849 request_id: Uuid::new_v4().to_string(),
1850 editable_range: 0..0,
1851 output: String::new(),
1852 model_version: None,
1853 }
1854}
1855
1856fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1857 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1858}
1859
1860fn assert_no_predict_request_ready(
1861 requests: &mut mpsc::UnboundedReceiver<(
1862 PredictEditsV3Request,
1863 oneshot::Sender<PredictEditsV3Response>,
1864 )>,
1865) {
1866 if requests.next().now_or_never().flatten().is_some() {
1867 panic!("Unexpected prediction request while throttled.");
1868 }
1869}
1870
/// Receiving ends for the fake HTTP endpoints installed by
/// `init_test_with_fake_client`. Each item pairs the decoded request body with
/// a one-shot sender: sending on it supplies the HTTP response, and dropping
/// it makes the HTTP request fail.
struct RequestChannels {
    /// Requests sent to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Rejection batches sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1878
/// Creates the global `EditPredictionStore` backed by a fake HTTP client whose
/// prediction and rejection endpoints are handed to the test.
///
/// - `/client/llm_tokens` responds immediately with a fixed token.
/// - `/predict_edits/v3` zstd-decodes the JSON body, sends the request plus a
///   one-shot response sender on `RequestChannels::predict`, and waits for
///   the test to answer. Dropping the sender makes the HTTP request fail.
/// - `/predict_edits/reject` does the same with the (uncompressed) JSON body
///   on `RequestChannels::reject`.
/// - Any other path panics.
///
/// Also installs a test `SettingsStore`, initializes test logging, and
/// initializes `language_model` with the new client and user store.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                // Each request's future gets its own sender handles.
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // Blocks this HTTP request until the test responds; a
                            // dropped sender propagates as an error via `?`.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(user_store.clone(), client.clone(), cx);
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1947
/// Checks how `EditPrediction::interpolate` re-expresses a prediction against
/// later buffer states.
///
/// The prediction on "Lorem ipsum dolor" replaces 2..5 ("rem") with "REM" and
/// deletes 9..11 ("um"). As the buffer is edited, the remaining predicted
/// edits are re-expressed against the new text. Typing a prefix of the
/// predicted insertion shrinks it, performing a predicted edit drops it, and
/// undo restores the original edits. An edit the prediction cannot be rebased
/// over makes `interpolate` return `None`.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            active_buffer_diagnostics: vec![],
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
        model_version: None,
    };

    cx.update(|cx| {
        // Unedited buffer: edits are reported unchanged.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Delete "rem": the replacement becomes a pure insertion at 2 and the
        // deletion shifts left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Replace "rem" with "R" (first char of the prediction): only "EM" is
        // left to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Type "E": only "M" is left to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Type "M": the replacement is complete; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Delete the typed "M": it is predicted again.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Perform the predicted deletion: only the "M" insertion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit the prediction cannot be rebased over invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
2070
2071#[gpui::test]
2072async fn test_clean_up_diff(cx: &mut TestAppContext) {
2073 init_test(cx);
2074
2075 assert_eq!(
2076 apply_edit_prediction(
2077 indoc! {"
2078 fn main() {
2079 let word_1 = \"lorem\";
2080 let range = word.len()..word.len();
2081 }
2082 "},
2083 indoc! {"
2084 fn main() {
2085 let word_1 = \"lorem\";
2086 let range = word_1.len()..word_1.len();
2087 }
2088 "},
2089 cx,
2090 )
2091 .await,
2092 indoc! {"
2093 fn main() {
2094 let word_1 = \"lorem\";
2095 let range = word_1.len()..word_1.len();
2096 }
2097 "},
2098 );
2099
2100 assert_eq!(
2101 apply_edit_prediction(
2102 indoc! {"
2103 fn main() {
2104 let story = \"the quick\"
2105 }
2106 "},
2107 indoc! {"
2108 fn main() {
2109 let story = \"the quick brown fox jumps over the lazy dog\";
2110 }
2111 "},
2112 cx,
2113 )
2114 .await,
2115 indoc! {"
2116 fn main() {
2117 let story = \"the quick brown fox jumps over the lazy dog\";
2118 }
2119 "},
2120 );
2121}
2122
2123#[gpui::test]
2124async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2125 init_test(cx);
2126
2127 let buffer_content = "lorem\n";
2128 let completion_response = "lorem\nipsum\n";
2129
2130 assert_eq!(
2131 apply_edit_prediction(buffer_content, completion_response, cx).await,
2132 "lorem\nipsum\n"
2133 );
2134}
2135
2136#[gpui::test]
2137async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
2138 // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
2139 // When the buffer ends without a trailing newline, but the model returns output
2140 // with a trailing newline, zeta2 should normalize both sides before diffing
2141 // so no spurious newline is inserted.
2142 let (ep_store, mut requests) = init_test_with_fake_client(cx);
2143 let fs = FakeFs::new(cx.executor());
2144
2145 // Single line buffer with no trailing newline
2146 fs.insert_tree(
2147 "/root",
2148 json!({
2149 "foo.txt": "hello"
2150 }),
2151 )
2152 .await;
2153 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2154
2155 let buffer = project
2156 .update(cx, |project, cx| {
2157 let path = project
2158 .find_project_path(path!("root/foo.txt"), cx)
2159 .unwrap();
2160 project.open_buffer(path, cx)
2161 })
2162 .await
2163 .unwrap();
2164
2165 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2166 let position = snapshot.anchor_before(language::Point::new(0, 5));
2167
2168 ep_store.update(cx, |ep_store, cx| {
2169 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2170 });
2171
2172 let (request, respond_tx) = requests.predict.next().await.unwrap();
2173
2174 // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2175 // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2176 let excerpt_length = request.input.cursor_excerpt.len();
2177 let response = PredictEditsV3Response {
2178 request_id: Uuid::new_v4().to_string(),
2179 output: "hello world\n".to_string(),
2180 editable_range: 0..excerpt_length,
2181 model_version: None,
2182 };
2183 respond_tx.send(response).unwrap();
2184
2185 cx.run_until_parked();
2186
2187 // The prediction should insert " world" without adding a newline
2188 ep_store.update(cx, |ep_store, cx| {
2189 let prediction = ep_store
2190 .prediction_at(&buffer, None, &project, cx)
2191 .expect("should have prediction");
2192 let edits: Vec<_> = prediction
2193 .edits
2194 .iter()
2195 .map(|(range, text)| {
2196 let snapshot = buffer.read(cx).snapshot();
2197 (range.to_offset(&snapshot), text.clone())
2198 })
2199 .collect();
2200 assert_eq!(edits, vec![(5..5, " world".into())]);
2201 });
2202}
2203
2204fn init_test(cx: &mut TestAppContext) {
2205 cx.update(|cx| {
2206 let settings_store = SettingsStore::test(cx);
2207 cx.set_global(settings_store);
2208 });
2209}
2210
2211async fn apply_edit_prediction(
2212 buffer_content: &str,
2213 completion_response: &str,
2214 cx: &mut TestAppContext,
2215) -> String {
2216 let fs = project::FakeFs::new(cx.executor());
2217 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2218 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2219 let (ep_store, response) = make_test_ep_store(&project, cx).await;
2220 *response.lock() = completion_response.to_string();
2221 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2222 buffer.update(cx, |buffer, cx| {
2223 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2224 });
2225 buffer.read_with(cx, |buffer, _| buffer.text())
2226}
2227
2228async fn run_edit_prediction(
2229 buffer: &Entity<Buffer>,
2230 project: &Entity<Project>,
2231 ep_store: &Entity<EditPredictionStore>,
2232 cx: &mut TestAppContext,
2233) -> EditPrediction {
2234 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2235 ep_store.update(cx, |ep_store, cx| {
2236 ep_store.register_buffer(buffer, &project, cx)
2237 });
2238 cx.background_executor.run_until_parked();
2239 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2240 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2241 });
2242 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2243}
2244
2245async fn make_test_ep_store(
2246 project: &Entity<Project>,
2247 cx: &mut TestAppContext,
2248) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2249 let default_response = "hello world\n".to_string();
2250 let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2251 let http_client = FakeHttpClient::create({
2252 let completion_response = completion_response.clone();
2253 let mut next_request_id = 0;
2254 move |req| {
2255 let completion_response = completion_response.clone();
2256 let method = req.method().clone();
2257 let uri = req.uri().path().to_string();
2258 let mut body = req.into_body();
2259 async move {
2260 match (method, uri.as_str()) {
2261 (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2262 .status(200)
2263 .body(
2264 serde_json::to_string(&CreateLlmTokenResponse {
2265 token: LlmToken("the-llm-token".to_string()),
2266 })
2267 .unwrap()
2268 .into(),
2269 )
2270 .unwrap()),
2271 (Method::POST, "/predict_edits/v3") => {
2272 let mut buf = Vec::new();
2273 body.read_to_end(&mut buf).await.ok();
2274 let decompressed = zstd::decode_all(&buf[..]).unwrap();
2275 let req: PredictEditsV3Request =
2276 serde_json::from_slice(&decompressed).unwrap();
2277
2278 next_request_id += 1;
2279 Ok(http_client::Response::builder()
2280 .status(200)
2281 .body(
2282 serde_json::to_string(&PredictEditsV3Response {
2283 request_id: format!("request-{next_request_id}"),
2284 editable_range: 0..req.input.cursor_excerpt.len(),
2285 output: completion_response.lock().clone(),
2286 model_version: None,
2287 })
2288 .unwrap()
2289 .into(),
2290 )
2291 .unwrap())
2292 }
2293 _ => Ok(http_client::Response::builder()
2294 .status(404)
2295 .body("Not Found".to_string().into())
2296 .unwrap()),
2297 }
2298 }
2299 }
2300 });
2301
2302 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2303 let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
2304 cx.update(|cx| {
2305 RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
2306 });
2307 let _server = FakeServer::for_client(42, &client, cx).await;
2308
2309 let ep_store = cx.new(|cx| {
2310 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2311 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2312
2313 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2314 for worktree in worktrees {
2315 let worktree_id = worktree.read(cx).id();
2316 ep_store
2317 .get_or_init_project(project, cx)
2318 .license_detection_watchers
2319 .entry(worktree_id)
2320 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2321 }
2322
2323 ep_store
2324 });
2325
2326 (ep_store, completion_response)
2327}
2328
2329fn to_completion_edits(
2330 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2331 buffer: &Entity<Buffer>,
2332 cx: &App,
2333) -> Vec<(Range<Anchor>, Arc<str>)> {
2334 let buffer = buffer.read(cx);
2335 iterator
2336 .into_iter()
2337 .map(|(range, text)| {
2338 (
2339 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2340 text,
2341 )
2342 })
2343 .collect()
2344}
2345
2346fn from_completion_edits(
2347 editor_edits: &[(Range<Anchor>, Arc<str>)],
2348 buffer: &Entity<Buffer>,
2349 cx: &App,
2350) -> Vec<(Range<usize>, Arc<str>)> {
2351 let buffer = buffer.read(cx);
2352 editor_edits
2353 .iter()
2354 .map(|(range, text)| {
2355 (
2356 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2357 text.clone(),
2358 )
2359 })
2360 .collect()
2361}
2362
2363#[gpui::test]
2364async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2365 init_test(cx);
2366
2367 let fs = FakeFs::new(cx.executor());
2368 fs.insert_tree(
2369 "/project",
2370 serde_json::json!({
2371 "main.rs": "fn main() {\n \n}\n"
2372 }),
2373 )
2374 .await;
2375
2376 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2377
2378 let http_client = FakeHttpClient::create(|_req| async move {
2379 Ok(gpui::http_client::Response::builder()
2380 .status(401)
2381 .body("Unauthorized".into())
2382 .unwrap())
2383 });
2384
2385 let client =
2386 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2387 let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
2388 cx.update(|cx| {
2389 language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
2390 });
2391
2392 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2393
2394 let buffer = project
2395 .update(cx, |project, cx| {
2396 let path = project
2397 .find_project_path(path!("/project/main.rs"), cx)
2398 .unwrap();
2399 project.open_buffer(path, cx)
2400 })
2401 .await
2402 .unwrap();
2403
2404 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2405 ep_store.update(cx, |ep_store, cx| {
2406 ep_store.register_buffer(&buffer, &project, cx)
2407 });
2408 cx.background_executor.run_until_parked();
2409
2410 let completion_task = ep_store.update(cx, |ep_store, cx| {
2411 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2412 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2413 });
2414
2415 let result = completion_task.await;
2416 assert!(
2417 result.is_err(),
2418 "Without authentication and without custom URL, prediction should fail"
2419 );
2420}
2421
2422#[gpui::test]
2423fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2424 let buffer = cx.new(|cx| {
2425 Buffer::local(
2426 indoc! {"
2427 zero
2428 one
2429 two
2430 three
2431 four
2432 five
2433 six
2434 seven
2435 eight
2436 nine
2437 ten
2438 eleven
2439 twelve
2440 thirteen
2441 fourteen
2442 fifteen
2443 sixteen
2444 seventeen
2445 eighteen
2446 nineteen
2447 twenty
2448 twenty-one
2449 twenty-two
2450 twenty-three
2451 twenty-four
2452 "},
2453 cx,
2454 )
2455 });
2456
2457 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2458
2459 buffer.update(cx, |buffer, cx| {
2460 let point = Point::new(12, 0);
2461 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2462 let point = Point::new(8, 0);
2463 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2464 });
2465
2466 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2467
2468 let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2469
2470 assert_eq!(
2471 diff,
2472 indoc! {"
2473 @@ -6,10 +6,12 @@
2474 five
2475 six
2476 seven
2477 +FIRST INSERTION
2478 eight
2479 nine
2480 ten
2481 eleven
2482 +SECOND INSERTION
2483 twelve
2484 thirteen
2485 fourteen
2486 "}
2487 );
2488}
2489
/// Verifies that `EditPredictionStore::next_diagnostic_location` steers away
/// from collaborators:
/// - in the active file, a diagnostic near a collaborator's cursor is skipped
///   unless the local cursor is also near it;
/// - when looking in other files, a file holding a collaborator's cursor is
///   skipped.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Simulates a remote collaborator (replica 10) whose cursor is an empty
    // selection at column 0 of `row`, by applying a remote
    // `Operation::UpdateSelections` to `buffer`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        // start == end: a bare cursor rather than a ranged selection.
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Publishes one ERROR diagnostic per entry of `rows` (columns 0..5) for
    // the file at `uri_path`, attributed to language server 0.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // Active file: 60 lines, "line 0" through "line 59".
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator at row 5. Diagnostics at row 3 (near the collaborator) and
    // at rows 25 and 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator. Row 3 must
    // not be chosen.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, next to the collaborator. Row 3 is
    // allowed again and expected.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50. The diagnostic on that row is picked.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4 (cross-file): both other files get a row-0 diagnostic, but only
    // collab_file.txt has a collaborator cursor.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // NOTE(review): this range spans the whole active file, presumably so
    // that no same-file diagnostic qualifies and the lookup falls through to
    // other files; confirm against `next_diagnostic_location`.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2705
/// Exercises the settled-prediction worker. Predictions are enqueued with
/// `enqueue_settled_prediction`. One is reported through
/// `settled_event_callback` only after its editable region has gone
/// `EDIT_PREDICTION_SETTLED_QUIESCENCE` (Q below) without edits. Predictions
/// for two separate regions settle independently: A first, then B.
#[gpui::test]
async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Buffer with two clearly separated regions. Each line is "line NN\n"
    // (8 bytes); offsets below are before any edits:
    //   Region A = lines 0-9  (offsets 0..80)
    //   Region B = lines 20-24 (offsets 160..200), the editable range used below
    // A big gap in between so edits in one region never overlap the other.
    let mut content = String::new();
    for i in 0..30 {
        content.push_str(&format!("line {i:02}\n"));
    }

    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content.clone()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Every settled event is recorded as (prediction id, text passed to the callback).
    type SettledEventRecord = (EditPredictionId, String);
    let settled_events: Arc<Mutex<Vec<SettledEventRecord>>> = Arc::new(Mutex::new(Vec::new()));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);

        let settled_events = settled_events.clone();
        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
            settled_events.lock().push((id, text));
        }));
    });

    // --- Phase 1: edit in region A and enqueue prediction A ---

    buffer.update(cx, |buffer, cx| {
        // Edit at the start of line 0.
        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    // Region A: first 10 lines of the buffer.
    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-a".into()),
            &project,
            &buffer,
            &snapshot_a,
            editable_region_a.clone(),
            None,
            cx,
        );
    });

    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---

    // Let the worker process the channel message before we start advancing.
    cx.run_until_parked();

    // Timeline (t = 0 at the first loop edit): A is edited at t = 0, Q/2 and Q;
    // the loop ends at t = 3Q/2.
    let mut region_a_edit_offset = 5;
    for _ in 0..3 {
        // Edit inside region A (not at the boundary) so `last_edit_at` is
        // updated before the worker's next wake.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
                None,
                cx,
            );
        });
        region_a_edit_offset += 1;
        cx.run_until_parked();

        cx.executor()
            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
        cx.run_until_parked();
        assert!(
            settled_events.lock().is_empty(),
            "no settled events should fire while region A is still being edited"
        );
    }

    // Still nothing settled.
    assert!(settled_events.lock().is_empty());

    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
    // Advance a small amount so B's quiescence window starts later than A's,
    // but not so much that A settles (A's last edit was at the start of
    // iteration 3, and it needs a full Q to settle).
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();
    assert!(settled_events.lock().is_empty());

    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));

    // B is edited at t = 7Q/4.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
    });
    cx.run_until_parked();

    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.enqueue_settled_prediction(
            EditPredictionId("prediction-b".into()),
            &project,
            &buffer,
            &snapshot_b2,
            editable_region_b.clone(),
            None,
            cx,
        );
    });

    cx.run_until_parked();
    assert!(
        settled_events.lock().is_empty(),
        "neither prediction should have settled yet"
    );

    // --- Phase 4: let enough time pass for region A to settle ---
    // A's last edit was at T_a (during the last loop iteration). The worker is
    // sleeping until T_a + Q. We advance just enough to reach that wake time
    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
    // enqueued only Q/4 ago and stays pending.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            1,
            "prediction and capture_sample for A should have settled, got: {events:?}"
        );
        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
    }

    // --- Phase 5: let more time pass for region B to settle ---
    // B's last edit was Q/4 before A settled. The worker rescheduled to
    // B's last_edit_at + Q, which is 3Q/4 from now.
    cx.executor()
        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
    cx.run_until_parked();

    {
        let events = settled_events.lock().clone();
        assert_eq!(
            events.len(),
            2,
            "both prediction and capture_sample settled events should be emitted for each request, got: {events:?}"
        );
        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
    }
}
2880
/// Initializes test logging via `zlog::init_test` for this test binary.
/// `#[ctor::ctor]` runs this function at load time, before any test executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}