1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
// End-to-end check of the store's "current prediction" state: a prediction in
// the active buffer is surfaced as `Local`, rejecting clears it, and a later
// prediction targeting another file is surfaced as `Jump` from the active
// buffer but `Local` when queried from the target buffer itself.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and mark it as the active path so predictions are anchored
    // to it.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction edits the buffer it was requested from, so it is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // Note: no explicit refresh here — the next predict request arrives after
    // the diagnostics update above.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction lives in a different file, so
    // it is reported as a `Jump` to root/2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from the target buffer, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
// Single `request_prediction` round-trip: respond with a one-hunk diff and
// verify the resulting edit's anchor position and inserted text.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor sits at the end of "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The "How" -> "How are you?" replacement interpolates to a single
    // insertion of " are you?" at the cursor position.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
236
// Edits made to a registered buffer are recorded as edit-history events and
// their diff is included in the prompt of the next prediction request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register first so the following edit is captured in the edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Fill the empty second line with "How".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt must contain the diff of the edit made above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
309
// Two bursts of edits separated by a pause longer than the grouping threshold
// are returned by the edit-history getter as two separate events, each with
// its own diff.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event covers only the pre-pause burst ("" -> "How").
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second event covers both post-pause edits ("How" -> "How are you?!").
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
387
// Line-distance-based coalescing of edit-history events: edits within 8 lines
// of an existing event's range (above or below) merge into it, while edits
// farther away start a new event.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the coalesced cluster and the distant insertion,
    // rendered with the `---` separator from `render_events`.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
566
567fn render_events(events: &[StoredEvent]) -> String {
568 events
569 .iter()
570 .map(|e| {
571 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
572 diff.as_str()
573 })
574 .collect::<Vec<_>>()
575 .join("\n---\n")
576}
577
578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
579 events
580 .iter()
581 .map(|e| {
582 let zeta_prompt::Event::BufferChange {
583 diff, predicted, ..
584 } = e.event.as_ref();
585 let prefix = if *predicted { "predicted" } else { "manual" };
586 format!("{}\n{}", prefix, diff)
587 })
588 .collect()
589}
590
// How the `predicted` flag interacts with event coalescing: manual edits merge
// with manual edits, accepted predictions merge with predictions, a manual
// edit never merges into a preceding predicted event, and once activity moves
// elsewhere a mixed cluster may merge with `predicted` meaning "all
// constituent edits were predicted".
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    //
    // `report_changes_for_buffer(.., true, ..)` marks the preceding edit as
    // coming from an accepted prediction.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // The earlier mixed manual+predicted cluster collapses into one `manual`
    // event (not all of its edits were predicted).
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
849
// An empty model response produces no prediction, and the request is
// automatically reported back as rejected with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff; keep the request id to match the rejection.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
904
// If the buffer is edited to already contain the predicted text before the
// response arrives, interpolation yields no remaining edits; the prediction is
// dropped and reported as rejected with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the predicted change manually while the request is in flight.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
964
/// Model-response diff shared by several tests below: rewrites "How" to
/// "How are you?" in `root/foo.md`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
974
// When a newer prediction arrives, it replaces the current one, and the
// replaced prediction is reported as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
1056
// If a newer prediction is worse than the current one, the current prediction
// is kept and the newer one is rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
1150
// Starting a newer refresh cancels an older pending request: even if the old
// request's response arrives later, it does not replace the newer prediction,
// and the old request is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the (cancelled) first request respond.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
1241
// Verifies pending-request eviction: when a third refresh starts while two
// predictions are already in flight, the *second* (newest pending) request is
// cancelled, the first result is shown until the third arrives, and both the
// cancelled and the replaced predictions are reported to the reject endpoint.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled (second) request; its result must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected (Canceled), and first as Replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1377
// Verifies that edit-triggered and diagnostic-("jump")-triggered prediction
// refreshes are throttled by separate timers: a recent edit request does not
// throttle a jump request and vice versa, while a second request of the same
// kind is throttled until THROTTLE_TIMEOUT elapses.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1478
// Verifies batching of rejection reports: debounced flushing, an immediate
// flush when the batch-size limit is reached (with the remainder debounced),
// and re-sending of a failed batch together with later rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched: both rejections arrive in a single request, in order
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    // NOTE(review): assumes 15s exceeds REJECT_REQUEST_DEBOUNCE — confirm,
    // and consider advancing by the constant itself for clarity.
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure by dropping the response sender (the fake HTTP handler
    // errors when the oneshot is cancelled)
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1595
1596// Skipped until we start including diagnostics in prompt
1597// #[gpui::test]
1598// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1599// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1600// let fs = FakeFs::new(cx.executor());
1601// fs.insert_tree(
1602// "/root",
1603// json!({
1604// "foo.md": "Hello!\nBye"
1605// }),
1606// )
1607// .await;
1608// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1609
1610// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1611// let diagnostic = lsp::Diagnostic {
1612// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1613// severity: Some(lsp::DiagnosticSeverity::ERROR),
1614// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1615// ..Default::default()
1616// };
1617
1618// project.update(cx, |project, cx| {
1619// project.lsp_store().update(cx, |lsp_store, cx| {
1620// // Create some diagnostics
1621// lsp_store
1622// .update_diagnostics(
1623// LanguageServerId(0),
1624// lsp::PublishDiagnosticsParams {
1625// uri: path_to_buffer_uri.clone(),
1626// diagnostics: vec![diagnostic],
1627// version: None,
1628// },
1629// None,
1630// language::DiagnosticSourceKind::Pushed,
1631// &[],
1632// cx,
1633// )
1634// .unwrap();
1635// });
1636// });
1637
1638// let buffer = project
1639// .update(cx, |project, cx| {
1640// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1641// project.open_buffer(path, cx)
1642// })
1643// .await
1644// .unwrap();
1645
1646// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1647// let position = snapshot.anchor_before(language::Point::new(0, 0));
1648
1649// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1650// ep_store.request_prediction(&project, &buffer, position, cx)
1651// });
1652
1653// let (request, _respond_tx) = req_rx.next().await.unwrap();
1654
1655// assert_eq!(request.diagnostic_groups.len(), 1);
1656// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1657// .unwrap();
1658// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1659// assert_eq!(
1660// value,
1661// json!({
1662// "entries": [{
1663// "range": {
1664// "start": 8,
1665// "end": 10
1666// },
1667// "diagnostic": {
1668// "source": null,
1669// "code": null,
1670// "code_description": null,
1671// "severity": 1,
1672// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1673// "markdown": null,
1674// "group_id": 0,
1675// "is_primary": true,
1676// "is_disk_based": false,
1677// "is_unnecessary": false,
1678// "source_kind": "Pushed",
1679// "data": null,
1680// "underline": true
1681// }
1682// }],
1683// "primary_ix": 0
1684// })
1685// );
1686// }
1687
1688// Generate a model response that would apply the given diff to the active file.
1689fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1690 let editable_range = request
1691 .input
1692 .excerpt_ranges
1693 .as_ref()
1694 .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1)
1695 .unwrap_or(request.input.editable_range_in_excerpt.clone());
1696 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1697 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1698
1699 PredictEditsV3Response {
1700 request_id: Uuid::new_v4().to_string(),
1701 editable_range,
1702 output: new_excerpt,
1703 }
1704}
1705
1706fn empty_response() -> PredictEditsV3Response {
1707 PredictEditsV3Response {
1708 request_id: Uuid::new_v4().to_string(),
1709 editable_range: 0..0,
1710 output: String::new(),
1711 }
1712}
1713
1714fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1715 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1716}
1717
1718fn assert_no_predict_request_ready(
1719 requests: &mut mpsc::UnboundedReceiver<(
1720 PredictEditsV3Request,
1721 oneshot::Sender<PredictEditsV3Response>,
1722 )>,
1723) {
1724 if requests.next().now_or_never().flatten().is_some() {
1725 panic!("Unexpected prediction request while throttled.");
1726 }
1727}
1728
/// Receiving ends of the fake HTTP client's request channels. Each incoming
/// request is paired with a oneshot sender the test uses to deliver its
/// response (or drop it, to simulate a failed request).
struct RequestChannels {
    // Requests hitting `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests hitting `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1736
/// Sets up an `EditPredictionStore` backed by a fake HTTP client that routes
/// predict/reject requests into channels, letting tests inspect each request
/// and control when and how it is answered.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: always succeed with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction bodies are zstd-compressed JSON; decode,
                        // forward to the test, then await the test's response.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // `?` turns a dropped sender into a request error.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection bodies are plain (uncompressed) JSON.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1806
// Verifies `EditPrediction::interpolate`: as the user edits the buffer after
// a prediction was produced, edits that the user has already (partially)
// typed are dropped from the prediction, undone edits are restored, and the
// prediction becomes `None` once the buffer diverges irreconcilably.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Prediction: replace "rem" (2..5) with "REM" and delete "do" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Construct a prediction directly, with mostly-empty prompt inputs.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unedited buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting "rem" shifts the edits accordingly.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted text shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Fully typing the first edit leaves only the deletion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the just-typed character re-introduces that edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An incompatible edit invalidates the prediction entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1926
// Verifies diff clean-up when applying a prediction: the computed edits
// should produce exactly the model's intended text, both for small
// intra-line renames and for an extended string literal.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Rename `word` -> `word_1` at two call sites on the same line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extend a string literal and add the missing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1978
1979#[gpui::test]
1980async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1981 init_test(cx);
1982
1983 let buffer_content = "lorem\n";
1984 let completion_response = "lorem\nipsum\n";
1985
1986 assert_eq!(
1987 apply_edit_prediction(buffer_content, completion_response, cx).await,
1988 "lorem\nipsum\n"
1989 );
1990}
1991
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor at end of "hello" (row 0, column 5).
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let excerpt_length = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
        editable_range: 0..excerpt_length,
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve the anchor-based edits to offsets for comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2058
2059fn init_test(cx: &mut TestAppContext) {
2060 cx.update(|cx| {
2061 let settings_store = SettingsStore::test(cx);
2062 cx.set_global(settings_store);
2063 });
2064}
2065
2066async fn apply_edit_prediction(
2067 buffer_content: &str,
2068 completion_response: &str,
2069 cx: &mut TestAppContext,
2070) -> String {
2071 let fs = project::FakeFs::new(cx.executor());
2072 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2073 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2074 let (ep_store, response) = make_test_ep_store(&project, cx).await;
2075 *response.lock() = completion_response.to_string();
2076 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2077 buffer.update(cx, |buffer, cx| {
2078 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2079 });
2080 buffer.read_with(cx, |buffer, _| buffer.text())
2081}
2082
/// Registers `buffer` with the store and requests a single prediction at a
/// fixed cursor position (start of row 1), returning the resulting
/// `EditPrediction`. Panics if the request fails or yields no prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting a prediction.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
2099
/// Builds an `EditPredictionStore` (Zeta1 model) whose HTTP client answers
/// every prediction request with a shared, mutable canned completion string.
/// Returns the store and the handle used to swap the canned response.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        // Counter giving each prediction response a unique request id.
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    // Token refresh: always succeed with a dummy token.
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction: decode the zstd-compressed request and reply
                    // with the current canned completion over the full excerpt.
                    (Method::POST, "/predict_edits/v3") => {
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install a license watcher per worktree, as the store expects.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2181
2182fn to_completion_edits(
2183 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2184 buffer: &Entity<Buffer>,
2185 cx: &App,
2186) -> Vec<(Range<Anchor>, Arc<str>)> {
2187 let buffer = buffer.read(cx);
2188 iterator
2189 .into_iter()
2190 .map(|(range, text)| {
2191 (
2192 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2193 text,
2194 )
2195 })
2196 .collect()
2197}
2198
2199fn from_completion_edits(
2200 editor_edits: &[(Range<Anchor>, Arc<str>)],
2201 buffer: &Entity<Buffer>,
2202 cx: &App,
2203) -> Vec<(Range<usize>, Arc<str>)> {
2204 let buffer = buffer.read(cx);
2205 editor_edits
2206 .iter()
2207 .map(|(range, text)| {
2208 (
2209 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2210 text.clone(),
2211 )
2212 })
2213 .collect()
2214}
2215
// Verifies that prediction requests fail when the client is unauthenticated
// (every HTTP request returns 401) and no custom prediction URL is set.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request — including the LLM token fetch — is rejected with 401.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2273
// Verifies that `compute_diff_between_snapshots` merges two nearby insertions
// into a single unified-diff hunk with the expected context lines.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insert at row 12 first, then row 8, so the second edit's coordinates
    // are unaffected by the first.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2341
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Fakes a remote collaborator by applying an UpdateSelections operation
    // from a non-local replica (id 10), placing their cursor at column 0 of
    // `row` in `buffer`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Publishes one ERROR diagnostic (columns 0..5) per entry in `rows` for
    // the file at `uri_path`, via the project's LSP store.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 lines gives enough room for "near" and "far" rows relative to both
    // the local cursor and the collaborator's cursor.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator sits at row 5; diagnostics at rows 3, 25, and 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Scenario 1: local cursor at row 25, far from the collaborator. The
    // diagnostic at row 3 (near the collaborator, far from us) must be
    // skipped in favor of rows 25 or 50.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Scenario 2: local cursor at row 4, i.e. also near the collaborator.
    // Proximity of the local cursor re-includes the row-3 diagnostic.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Scenario 3: local cursor at row 50, directly on a diagnostic that is
    // far from the collaborator — it should be chosen.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Cross-file scenario: both other files have a diagnostic at row 0, but
    // collab_file.txt has a collaborator cursor on it.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // A search range covering the whole active file leaves no same-file
    // candidates, forcing the cross-file fallback.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2557
// Installs the test logger for this test binary. The `ctor` attribute makes
// this run at load time, before any test executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}