1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
// Verifies how `prediction_at` classifies a prediction relative to the queried
// buffer: an edit targeting the buffer itself surfaces as `Local`, while an
// edit targeting another file (triggered here by a diagnostic in that file)
// surfaces as `Jump` until the target buffer is queried directly.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Mark 1.txt as the active file before opening it.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on the second line.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Respond with a diff that edits the buffer the cursor is in.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The edit targets buffer1 itself, so it is reported as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction before driving the next scenario.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request even
    // though 2.txt is not the active buffer.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                Hola!
                -Como
                +Como estas?
                Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from buffer1, the prediction for 2.txt appears as a `Jump`
    // pointing at that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
170#[gpui::test]
171async fn test_simple_request(cx: &mut TestAppContext) {
172 let (ep_store, mut requests) = init_test_with_fake_client(cx);
173 let fs = FakeFs::new(cx.executor());
174 fs.insert_tree(
175 "/root",
176 json!({
177 "foo.md": "Hello!\nHow\nBye\n"
178 }),
179 )
180 .await;
181 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
182
183 let buffer = project
184 .update(cx, |project, cx| {
185 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
186 project.open_buffer(path, cx)
187 })
188 .await
189 .unwrap();
190 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
191 let position = snapshot.anchor_before(language::Point::new(1, 3));
192
193 let prediction_task = ep_store.update(cx, |ep_store, cx| {
194 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
195 });
196
197 let (request, respond_tx) = requests.predict.next().await.unwrap();
198
199 // TODO Put back when we have a structured request again
200 // assert_eq!(
201 // request.excerpt_path.as_ref(),
202 // Path::new(path!("root/foo.md"))
203 // );
204 // assert_eq!(
205 // request.cursor_point,
206 // Point {
207 // line: Line(1),
208 // column: 3
209 // }
210 // );
211
212 respond_tx
213 .send(model_response(
214 &request,
215 indoc! { r"
216 --- a/root/foo.md
217 +++ b/root/foo.md
218 @@ ... @@
219 Hello!
220 -How
221 +How are you?
222 Bye
223 "},
224 ))
225 .unwrap();
226
227 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
228
229 assert_eq!(prediction.edits.len(), 1);
230 assert_eq!(
231 prediction.edits[0].0.to_point(&snapshot).start,
232 language::Point::new(1, 3)
233 );
234 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
235}
236
// Verifies that recent buffer edits are recorded as events and included in the
// prompt sent with a prediction request, rendered as a unified diff.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register the buffer so the store observes subsequent edits.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" on the (previously empty) second line; this becomes the
    // edit-history event the prompt should describe.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The outgoing prompt must contain the diff of the edit made above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The response diff yields a single insertion at the cursor.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
309
// Verifies that a pause longer than the grouping threshold between edit bursts
// splits the edit history into two distinct events, each carrying its own diff.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register the buffer so its edits are tracked in the store's history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: only the pre-pause insertion of "How".
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}
    );

    // Second event: both post-pause edits, coalesced together.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -How
            +How are you?!
            Bye
        "}
    );
}
387
// Exercises line-distance-based coalescing of edit-history events: edits within
// 8 lines of an existing event's range (above or below) merge into that event,
// while edits farther away start a new event.
//
// NOTE(review): the function name mentions predicted edits, but the body only
// performs manual edits and tests coalescing distance — the predicted/manual
// separation is covered by `test_predicted_flag_coalescing` below. Consider
// renaming; confirm against test-suite conventions before doing so.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register the buffer so edits are captured in the history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17

            ---
            @@ -23,6 +23,7 @@
            Line 22
            Line 23
            Line 24
            +Far below
            Line 25
            Line 26
            Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
566
567fn render_events(events: &[StoredEvent]) -> String {
568 events
569 .iter()
570 .map(|e| {
571 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
572 diff.as_str()
573 })
574 .collect::<Vec<_>>()
575 .join("\n---\n")
576}
577
578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
579 events
580 .iter()
581 .map(|e| {
582 let zeta_prompt::Event::BufferChange {
583 diff, predicted, ..
584 } = e.event.as_ref();
585 let prefix = if *predicted { "predicted" } else { "manual" };
586 format!("{}\n{}", prefix, diff)
587 })
588 .collect()
589}
590
// Verifies how the `predicted` flag interacts with event coalescing:
//   1. manual edits produce `predicted == false` events;
//   2. nearby successive manual edits merge into one manual event;
//   3. accepted predictions produce `predicted == true` events;
//   4. nearby successive predicted edits merge into one predicted event;
//   5. a manual edit after a predicted one starts a new event even when nearby;
//   6. once activity moves far away, prior mixed events may merge — `predicted`
//      then means all constituent edits were predicted.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register the buffer so its edits are recorded in the history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
            line 1
            line 2
            line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
            line 2
            line 3
            line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    // (`report_changes_for_buffer` with `true` marks the change as predicted.)
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                LINE ZERO
                LINE ONE
                -line 2
                +LINE TWO
                line 3
                line 4
                line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                LINE ZERO
                LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                line 4
                line 5
                line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                line 2
                line 3
                line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                LINE ZERO
                LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                line 4
                line 5
                line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                LINE ONE
                LINE TWO
                LINE THREE
                -line 4
                +LINE FOUR
                line 5
                line 6
                line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                line 5
                line 6
                line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                line 11
                line 12
                line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
849
850#[gpui::test]
851async fn test_empty_prediction(cx: &mut TestAppContext) {
852 let (ep_store, mut requests) = init_test_with_fake_client(cx);
853 let fs = FakeFs::new(cx.executor());
854 fs.insert_tree(
855 "/root",
856 json!({
857 "foo.md": "Hello!\nHow\nBye\n"
858 }),
859 )
860 .await;
861 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
862
863 let buffer = project
864 .update(cx, |project, cx| {
865 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
866 project.open_buffer(path, cx)
867 })
868 .await
869 .unwrap();
870 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
871 let position = snapshot.anchor_before(language::Point::new(1, 3));
872
873 ep_store.update(cx, |ep_store, cx| {
874 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
875 });
876
877 let (request, respond_tx) = requests.predict.next().await.unwrap();
878 let response = model_response(&request, "");
879 let id = response.request_id.clone();
880 respond_tx.send(response).unwrap();
881
882 cx.run_until_parked();
883
884 ep_store.update(cx, |ep_store, cx| {
885 assert!(
886 ep_store
887 .prediction_at(&buffer, None, &project, cx)
888 .is_none()
889 );
890 });
891
892 // prediction is reported as rejected
893 let (reject_request, _) = requests.reject.next().await.unwrap();
894
895 assert_eq!(
896 &reject_request.rejections,
897 &[EditPredictionRejection {
898 request_id: id,
899 reason: EditPredictionRejectReason::Empty,
900 was_shown: false
901 }]
902 );
903}
904
// If the user manually types the predicted text while the request is in flight,
// interpolating the response against the new buffer content leaves no edits;
// the prediction is dropped and reported as rejected with `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is pending, the buffer is changed to already contain
    // the text the model is about to predict.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // Interpolation produced no remaining edits, so there is no prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
964
/// Canned model response shared by several tests below: a unified diff against
/// `root/foo.md` that rewrites its second line from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
974
// When a newer prediction arrives, it replaces the current one and the replaced
// prediction is reported back with rejection reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response is now the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
1056
// When a newer prediction is worse than the current one, the current prediction
// is kept and the newer one is rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response is now the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
1150
// When a later request completes before an earlier in-flight one, the earlier
// request is cancelled: its late response does not replace the current
// prediction, and it is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now the earlier request answers — too late, it was already cancelled.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
1241
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    // With two predictions already in flight, starting a third cancels the
    // second (the newest pending one). A cancelled request's response must
    // never become the current prediction, and every discarded prediction is
    // reported to the server with the appropriate rejection reason.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Respond to the first request; it should become the current prediction.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled (second) request; its result must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both discarded predictions are reported: the second as cancelled, and
    // the first as replaced (by the third).
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1377
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    // Edit-triggered predictions and diagnostic("jump")-triggered predictions
    // are throttled by separate timers: issuing one kind of request must not
    // delay the first request of the other kind, while a second request of the
    // same kind within the throttle window is held back.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1478
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    // Exercises how rejection reports are flushed to the server:
    // 1. multiple rejections within the debounce window are sent as one batch;
    // 2. hitting the batch-size limit flushes immediately, without debounce;
    // 3. a failed flush is retried, merged with newly queued rejections.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch.
    // 15s comfortably exceeds the debounce interval.
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure: dropping the responder makes the in-flight HTTP
    // request resolve with an error.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1595
1596// Skipped until we start including diagnostics in prompt
1597// #[gpui::test]
1598// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1599// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1600// let fs = FakeFs::new(cx.executor());
1601// fs.insert_tree(
1602// "/root",
1603// json!({
1604// "foo.md": "Hello!\nBye"
1605// }),
1606// )
1607// .await;
1608// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1609
1610// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1611// let diagnostic = lsp::Diagnostic {
1612// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1613// severity: Some(lsp::DiagnosticSeverity::ERROR),
1614// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1615// ..Default::default()
1616// };
1617
1618// project.update(cx, |project, cx| {
1619// project.lsp_store().update(cx, |lsp_store, cx| {
1620// // Create some diagnostics
1621// lsp_store
1622// .update_diagnostics(
1623// LanguageServerId(0),
1624// lsp::PublishDiagnosticsParams {
1625// uri: path_to_buffer_uri.clone(),
1626// diagnostics: vec![diagnostic],
1627// version: None,
1628// },
1629// None,
1630// language::DiagnosticSourceKind::Pushed,
1631// &[],
1632// cx,
1633// )
1634// .unwrap();
1635// });
1636// });
1637
1638// let buffer = project
1639// .update(cx, |project, cx| {
1640// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1641// project.open_buffer(path, cx)
1642// })
1643// .await
1644// .unwrap();
1645
1646// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1647// let position = snapshot.anchor_before(language::Point::new(0, 0));
1648
1649// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1650// ep_store.request_prediction(&project, &buffer, position, cx)
1651// });
1652
1653// let (request, _respond_tx) = req_rx.next().await.unwrap();
1654
1655// assert_eq!(request.diagnostic_groups.len(), 1);
1656// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1657// .unwrap();
1658// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1659// assert_eq!(
1660// value,
1661// json!({
1662// "entries": [{
1663// "range": {
1664// "start": 8,
1665// "end": 10
1666// },
1667// "diagnostic": {
1668// "source": null,
1669// "code": null,
1670// "code_description": null,
1671// "severity": 1,
1672// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1673// "markdown": null,
1674// "group_id": 0,
1675// "is_primary": true,
1676// "is_disk_based": false,
1677// "is_unnecessary": false,
1678// "source_kind": "Pushed",
1679// "data": null,
1680// "underline": true
1681// }
1682// }],
1683// "primary_ix": 0
1684// })
1685// );
1686// }
1687
1688// Generate a model response that would apply the given diff to the active file.
1689fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1690 let excerpt =
1691 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1692 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1693
1694 PredictEditsV3Response {
1695 request_id: Uuid::new_v4().to_string(),
1696 output: new_excerpt,
1697 }
1698}
1699
1700fn empty_response() -> PredictEditsV3Response {
1701 PredictEditsV3Response {
1702 request_id: Uuid::new_v4().to_string(),
1703 output: String::new(),
1704 }
1705}
1706
// Render the prompt text the model would see for the given request, using the
// default prompt format.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
}
1710
1711fn assert_no_predict_request_ready(
1712 requests: &mut mpsc::UnboundedReceiver<(
1713 PredictEditsV3Request,
1714 oneshot::Sender<PredictEditsV3Response>,
1715 )>,
1716) {
1717 if requests.next().now_or_never().flatten().is_some() {
1718 panic!("Unexpected prediction request while throttled.");
1719 }
1720}
1721
/// Receiving ends of the channels that the fake HTTP client forwards
/// intercepted cloud requests to. Each item pairs the decoded request body
/// with a oneshot sender the test uses to supply (or, by dropping, fail)
/// the response.
struct RequestChannels {
    // Requests to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1729
/// Set up an `EditPredictionStore` backed by a fake HTTP client that
/// intercepts cloud endpoints and forwards them to test-controlled channels,
/// letting each test inspect requests and choose when/how to respond.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token requests are answered inline with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            // Prediction bodies arrive zstd-compressed; decode
                            // before deserializing.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            // Hand the request to the test and await its reply;
                            // dropping `res_tx` simulates a request failure.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            // Rejection bodies are plain (uncompressed) JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1799
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    // `interpolate` remaps a prediction's edits onto the buffer's current
    // contents as the user types. Edits the user has already typed themselves
    // are dropped, remaining edits shift with the text, and once the buffer
    // diverges past what can be reconciled, `interpolate` returns None.
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Prediction: replace "rem" (2..5) with "REM" and delete "um" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Only `edits`, `buffer`, and `snapshot` matter here; the remaining
    // fields are filled with inert placeholder values.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation reproduces the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the to-be-replaced text collapses the first edit to an insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the first predicted character shrinks the remaining insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // First edit fully typed: only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the just-typed character brings the insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // User performs the predicted deletion themselves.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1919
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Renaming a variable's uses: the model response should apply cleanly.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extending a string literal mid-line (and adding the trailing semicolon).
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1971
1972#[gpui::test]
1973async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1974 init_test(cx);
1975
1976 let buffer_content = "lorem\n";
1977 let completion_response = "lorem\nipsum\n";
1978
1979 assert_eq!(
1980 apply_edit_prediction(buffer_content, completion_response, cx).await,
1981 "lorem\nipsum\n"
1982 );
1983}
1984
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor at the end of the single line ("hello" has 5 columns).
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve the anchor-based edits back to plain offsets for comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2049
/// Minimal test setup: installs a test `SettingsStore` as a global.
fn init_test(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
    });
}
2056
/// Request a prediction for a buffer containing `buffer_content`, with the
/// fake server configured to return `completion_response`, then apply the
/// resulting edits and return the buffer's final text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    let (ep_store, response) = make_test_ep_store(&project, cx).await;
    // Swap in the response text the fake HTTP client should return.
    *response.lock() = completion_response.to_string();
    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
2073
/// Register `buffer` with the store and request a single prediction with the
/// cursor anchored at the start of line 1, returning the resulting prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
2090
/// Build an `EditPredictionStore` (Zeta1 model) whose fake HTTP client answers
/// every prediction request with the text held in the returned
/// `Arc<Mutex<String>>`; tests mutate that string to control responses.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        // Gives each prediction response a distinct, deterministic request id.
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (&Method::POST, "/predict_edits/v3") => {
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    // Current contents of the shared response string.
                                    output: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install license watchers for each worktree so data-collection checks
        // have something to consult.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2162
2163fn to_completion_edits(
2164 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2165 buffer: &Entity<Buffer>,
2166 cx: &App,
2167) -> Vec<(Range<Anchor>, Arc<str>)> {
2168 let buffer = buffer.read(cx);
2169 iterator
2170 .into_iter()
2171 .map(|(range, text)| {
2172 (
2173 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2174 text,
2175 )
2176 })
2177 .collect()
2178}
2179
2180fn from_completion_edits(
2181 editor_edits: &[(Range<Anchor>, Arc<str>)],
2182 buffer: &Entity<Buffer>,
2183 cx: &App,
2184) -> Vec<(Range<usize>, Arc<str>)> {
2185 let buffer = buffer.read(cx);
2186 editor_edits
2187 .iter()
2188 .map(|(range, text)| {
2189 (
2190 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2191 text.clone(),
2192 )
2193 })
2194 .collect()
2195}
2196
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    // When the server rejects every request (401) and no custom prediction URL
    // is configured, requesting a prediction must fail rather than hang or
    // silently succeed.
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every endpoint — including token acquisition — answers 401.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2254
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    // Two insertions far apart should still collapse into a single unified
    // hunk (the insertions are close enough for their context to merge).
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    buffer.update(cx, |buffer, cx| {
        // Insert bottom-up so the second edit's coordinates aren't shifted
        // by the first.
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2322
/// Verifies that diagnostic-jump targeting skips diagnostics located near a
/// collaborator's cursor — unless the local cursor is also near them — both
/// within the active buffer and when falling back to other files.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Applies a remote collaborator's selection (replica id 10) as a single
    // caret at the start of `row`, so the store treats that region of the
    // buffer as occupied by another user.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Publishes one ERROR diagnostic (columns 0..5) per entry in `rows` for
    // the file at `uri_path`, as if pushed by a language server.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 lines so diagnostic rows 3, 25, and 50 are well separated.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator sits at row 5; diagnostics exist at rows 3, 25, and 50.
    // Row 3 is therefore inside the collaborator's exclusion region.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator. Row 3
    // should be excluded; only rows 25 or 50 are acceptable targets.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, right next to the collaborator. The
    // exclusion no longer applies, so row 3 becomes a valid target.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50 — the diagnostic on the same row, far
    // from the collaborator, should be chosen.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Cross-file fallback: both other files have a diagnostic at row 0, but
    // collab_file.txt has a collaborator cursor on that row.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // A search range spanning the whole active buffer marks its own
    // diagnostics as already covered, forcing a cross-file jump.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2538
/// Initializes test logging for this binary. The `ctor` attribute runs this
/// at program startup, before any test executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}