1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
3use client::{UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
8 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
9};
10use futures::{
11 AsyncReadExt, FutureExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::App;
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::path;
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    // Verifies how `prediction_at` surfaces predictions relative to the queried
    // buffer: a prediction for the queried buffer itself is `Local`, while a
    // prediction for another file shows up as `Jump` until that file's buffer is
    // the one being queried.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Mark 1.txt as the active file so predictions are requested for it.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" (row 1, column 3).
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Respond with a diff that edits the queried buffer (1.txt) itself.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        // The edit targets buffer1, so it's presented as a local prediction.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers another prediction request.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // This time the model's edit lands in 2.txt, not the queried buffer.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        // Queried from buffer1, a prediction in 2.txt is a jump target.
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        // Queried from the predicted file itself, the same prediction is local.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    // Happy path: a single prediction request whose model diff is interpolated
    // into concrete buffer edits at the expected position.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff replaces "How" with "How are you?", which interpolates to a
    // single insertion of " are you?" at the cursor position.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
236
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    // Verifies that edits made before a prediction request are included in the
    // prompt as an edit-history diff.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register first so the store records subsequent buffer edits as events.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Type "How" on the (previously empty) second line.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt should carry the recorded edit as a unified diff.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
309
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    // Verifies that a pause longer than the grouping threshold splits edit
    // history into two separate events, even when the edits touch the same line.
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 0: everything typed before the pause.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 1: both post-pause edits, coalesced together.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
387
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    // Exercises line-distance-based coalescing of edit-history events: edits
    // within 8 lines of an existing event's range (measured from either end)
    // merge into it; edits farther away start a new event.
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
566
567fn render_events(events: &[StoredEvent]) -> String {
568 events
569 .iter()
570 .map(|e| {
571 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
572 diff.as_str()
573 })
574 .collect::<Vec<_>>()
575 .join("\n---\n")
576}
577
578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
579 events
580 .iter()
581 .map(|e| {
582 let zeta_prompt::Event::BufferChange {
583 diff, predicted, ..
584 } = e.event.as_ref();
585 let prefix = if *predicted { "predicted" } else { "manual" };
586 format!("{}\n{}", prefix, diff)
587 })
588 .collect()
589}
590
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    // Exercises how the `predicted` flag interacts with event coalescing:
    // manual edits merge with manual edits, accepted predictions merge with
    // predictions, but a manual edit never merges into a predicted event
    // (and vice versa) while that event is still the most recent one.
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        // Reporting with `predicted == true` marks this change as an accepted
        // prediction rather than a manual edit.
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
849
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    // A model response containing no edits should produce no prediction and be
    // auto-reported as rejected with `EditPredictionRejectReason::Empty`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Respond with an empty diff (no predicted edits).
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
905
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    // If the user's own edits (made while the request was in flight) already
    // produce the model's suggested text, interpolation yields no remaining
    // edits and the prediction is rejected as `InterpolatedEmpty`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is pending, type the exact change the model will suggest.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
966
/// Canned model-response diff shared by several tests: replaces "How" with
/// "How are you?" on line 2 of `root/foo.md`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
976
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    // A newer prediction of equal-or-better quality replaces the current one,
    // and the replaced prediction is reported with reason `Replaced`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Same diff as the first response, so the newer prediction wins.
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1059
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    // If a newer prediction is worse than the current one, the current
    // prediction is kept and the newer one is rejected as `CurrentPreferred`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1154
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    // When a later request resolves first, the earlier in-flight request is
    // cancelled: its late response is ignored and it is reported as `Canceled`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now deliver the first (stale) response; it must not displace the second.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1246
// Verifies that when a third refresh is requested while two predictions are
// already in flight, the *second* (most recent pending) request is cancelled,
// and that completed-but-cancelled responses never replace the current
// prediction. Also checks that the cancelled and replaced predictions are
// reported to the reject endpoint with the correct reasons.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Complete the first request; it was never cancelled, so its prediction
    // should become current.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Complete the second request; since it was cancelled, its response must
    // be discarded rather than becoming the current prediction.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both discarded predictions are batched into one reject request:
    // the second with `Canceled`, the first with `Replaced` (by the third).
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1384
// Verifies that edit-triggered and diagnostic("jump")-triggered prediction
// refreshes each have their own throttle timer: a recent edit request does
// not throttle a jump request (and vice versa), while two requests of the
// same kind within the throttle window are held back until the timeout.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1485
// Verifies how rejection reports are flushed to the server:
// - multiple rejections within the debounce window are batched into one request,
// - hitting the batch-size limit flushes half the queue immediately (no debounce),
// - a failed flush keeps its items queued and retries them with the next batch.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            None,
            cx,
        );
    });

    // Let the debounce timer elapse so the queued rejections are sent.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false,
            model_version: None,
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true,
            model_version: None,
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                None,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            None,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1609
1610// Skipped until we start including diagnostics in prompt
1611// #[gpui::test]
1612// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1613// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1614// let fs = FakeFs::new(cx.executor());
1615// fs.insert_tree(
1616// "/root",
1617// json!({
1618// "foo.md": "Hello!\nBye"
1619// }),
1620// )
1621// .await;
1622// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1623
1624// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1625// let diagnostic = lsp::Diagnostic {
1626// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1627// severity: Some(lsp::DiagnosticSeverity::ERROR),
1628// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1629// ..Default::default()
1630// };
1631
1632// project.update(cx, |project, cx| {
1633// project.lsp_store().update(cx, |lsp_store, cx| {
1634// // Create some diagnostics
1635// lsp_store
1636// .update_diagnostics(
1637// LanguageServerId(0),
1638// lsp::PublishDiagnosticsParams {
1639// uri: path_to_buffer_uri.clone(),
1640// diagnostics: vec![diagnostic],
1641// version: None,
1642// },
1643// None,
1644// language::DiagnosticSourceKind::Pushed,
1645// &[],
1646// cx,
1647// )
1648// .unwrap();
1649// });
1650// });
1651
1652// let buffer = project
1653// .update(cx, |project, cx| {
1654// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1655// project.open_buffer(path, cx)
1656// })
1657// .await
1658// .unwrap();
1659
1660// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1661// let position = snapshot.anchor_before(language::Point::new(0, 0));
1662
1663// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1664// ep_store.request_prediction(&project, &buffer, position, cx)
1665// });
1666
1667// let (request, _respond_tx) = req_rx.next().await.unwrap();
1668
1669// assert_eq!(request.diagnostic_groups.len(), 1);
1670// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1671// .unwrap();
1672// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1673// assert_eq!(
1674// value,
1675// json!({
1676// "entries": [{
1677// "range": {
1678// "start": 8,
1679// "end": 10
1680// },
1681// "diagnostic": {
1682// "source": null,
1683// "code": null,
1684// "code_description": null,
1685// "severity": 1,
1686// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1687// "markdown": null,
1688// "group_id": 0,
1689// "is_primary": true,
1690// "is_disk_based": false,
1691// "is_unnecessary": false,
1692// "source_kind": "Pushed",
1693// "data": null,
1694// "underline": true
1695// }
1696// }],
1697// "primary_ix": 0
1698// })
1699// );
1700// }
1701
1702// Generate a model response that would apply the given diff to the active file.
1703fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1704 let editable_range = request
1705 .input
1706 .excerpt_ranges
1707 .as_ref()
1708 .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1)
1709 .unwrap_or(request.input.editable_range_in_excerpt.clone());
1710 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1711 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1712
1713 PredictEditsV3Response {
1714 request_id: Uuid::new_v4().to_string(),
1715 editable_range,
1716 output: new_excerpt,
1717 model_version: None,
1718 }
1719}
1720
1721fn empty_response() -> PredictEditsV3Response {
1722 PredictEditsV3Response {
1723 request_id: Uuid::new_v4().to_string(),
1724 editable_range: 0..0,
1725 output: String::new(),
1726 model_version: None,
1727 }
1728}
1729
1730fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1731 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1732}
1733
1734fn assert_no_predict_request_ready(
1735 requests: &mut mpsc::UnboundedReceiver<(
1736 PredictEditsV3Request,
1737 oneshot::Sender<PredictEditsV3Response>,
1738 )>,
1739) {
1740 if requests.next().now_or_never().flatten().is_some() {
1741 panic!("Unexpected prediction request while throttled.");
1742 }
1743}
1744
// Receivers for requests captured by the fake HTTP client in
// `init_test_with_fake_client`. Each captured request carries a oneshot
// sender the test uses to supply the response.
struct RequestChannels {
    // Requests to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1752
// Set up an `EditPredictionStore` backed by a fake HTTP client. Predict and
// reject requests are forwarded over the returned `RequestChannels`, letting
// tests inspect each request and choose when/how to respond.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token fetch: always succeeds with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction: body is zstd-compressed JSON. Forward the
                        // decoded request to the test and await its response.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection reporting: plain JSON body. Dropping the
                        // test's sender makes `res_rx.await` fail, which
                        // simulates a failed request.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1822
// Verifies that a prediction's edits are re-anchored ("interpolated") as the
// buffer changes after the prediction was made: edits the user has already
// typed are dropped, surviving edits shift with the text, and the prediction
// becomes invalid (`None`) once the buffer diverges from it.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Prediction: replace "rem" (2..5) with "REM" and delete "su" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
        model_version: None,
    };

    cx.update(|cx| {
        // Unchanged buffer: interpolation returns the edits as-is.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced text turns the replacement into an insertion
        // and shifts the later deletion left.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted replacement shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the entire replacement consumes the first edit; only the
        // deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of what was typed re-introduces the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion manually consumes the second edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1943
// Verifies that applying a model prediction produces exactly the model's
// target text — the diff between buffer and completion is cleaned up so
// partially-overlapping edits (renames, extended string literals) apply
// without duplication or truncation.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Rename `word` -> `word_1` at two call sites on the same line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extend a string literal and terminate the statement.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1995
1996#[gpui::test]
1997async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1998 init_test(cx);
1999
2000 let buffer_content = "lorem\n";
2001 let completion_response = "lorem\nipsum\n";
2002
2003 assert_eq!(
2004 apply_edit_prediction(buffer_content, completion_response, cx).await,
2005 "lorem\nipsum\n"
2006 );
2007}
2008
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor at the end of "hello".
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let excerpt_length = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
        editable_range: 0..excerpt_length,
        model_version: None,
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve the prediction's anchor ranges to offsets for comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2076
2077fn init_test(cx: &mut TestAppContext) {
2078 cx.update(|cx| {
2079 let settings_store = SettingsStore::test(cx);
2080 cx.set_global(settings_store);
2081 });
2082}
2083
// Run a full prediction round-trip: create a buffer with `buffer_content`,
// have the fake server respond with `completion_response`, apply the
// resulting prediction's edits, and return the final buffer text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    let (ep_store, response) = make_test_ep_store(&project, cx).await;
    // Configure what the fake server will return for the next prediction.
    *response.lock() = completion_response.to_string();
    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
2100
// Register `buffer` with the store and request a single prediction at the
// start of line 1, returning the resulting `EditPrediction`.
// Panics if the request fails or yields no prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
2117
// Build an `EditPredictionStore` whose HTTP client serves canned responses.
// Returns the store and a shared string: set it to control the output of the
// next `/predict_edits/v3` response.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        // Monotonic counter so each prediction gets a unique request id.
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    // Token fetch: respond with a fixed dummy token.
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction: decode the zstd-compressed request and echo
                    // back the currently configured completion text, editable
                    // over the whole excerpt.
                    (Method::POST, "/predict_edits/v3") => {
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install a license watcher per worktree so data-collection checks
        // have something to consult.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2200
2201fn to_completion_edits(
2202 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2203 buffer: &Entity<Buffer>,
2204 cx: &App,
2205) -> Vec<(Range<Anchor>, Arc<str>)> {
2206 let buffer = buffer.read(cx);
2207 iterator
2208 .into_iter()
2209 .map(|(range, text)| {
2210 (
2211 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2212 text,
2213 )
2214 })
2215 .collect()
2216}
2217
2218fn from_completion_edits(
2219 editor_edits: &[(Range<Anchor>, Arc<str>)],
2220 buffer: &Entity<Buffer>,
2221 cx: &App,
2222) -> Vec<(Range<usize>, Arc<str>)> {
2223 let buffer = buffer.read(cx);
2224 editor_edits
2225 .iter()
2226 .map(|(range, text)| {
2227 (
2228 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2229 text.clone(),
2230 )
2231 })
2232 .collect()
2233}
2234
// Verifies that when the server rejects authentication (401) and no custom
// prediction URL is configured, requesting a prediction fails rather than
// silently returning nothing.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request — including the token fetch — fails with 401.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor inside the empty function body.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2292
2293#[gpui::test]
2294fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2295 let buffer = cx.new(|cx| {
2296 Buffer::local(
2297 indoc! {"
2298 zero
2299 one
2300 two
2301 three
2302 four
2303 five
2304 six
2305 seven
2306 eight
2307 nine
2308 ten
2309 eleven
2310 twelve
2311 thirteen
2312 fourteen
2313 fifteen
2314 sixteen
2315 seventeen
2316 eighteen
2317 nineteen
2318 twenty
2319 twenty-one
2320 twenty-two
2321 twenty-three
2322 twenty-four
2323 "},
2324 cx,
2325 )
2326 });
2327
2328 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2329
2330 buffer.update(cx, |buffer, cx| {
2331 let point = Point::new(12, 0);
2332 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2333 let point = Point::new(8, 0);
2334 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2335 });
2336
2337 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2338
2339 let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2340
2341 assert_eq!(
2342 diff,
2343 indoc! {"
2344 @@ -6,10 +6,12 @@
2345 five
2346 six
2347 seven
2348 +FIRST INSERTION
2349 eight
2350 nine
2351 ten
2352 eleven
2353 +SECOND INSERTION
2354 twelve
2355 thirteen
2356 fourteen
2357 "}
2358 );
2359}
2360
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Simulates a remote collaborator by applying an `UpdateSelections`
    // operation from replica 10 that places a single caret at the start of
    // `row` in `buffer`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Pushes one ERROR diagnostic (columns 0..5) per entry in `rows` to the
    // project's LspStore for the file at `uri_path`, as if an LSP server
    // published them.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 lines so diagnostics at rows 3, 25, and 50 are well separated.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator caret at row 5 — near the diagnostic at row 3.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    let cursor_point = Point::new(25, 0);
    // Empty range — presumably means no same-file region has been searched
    // yet; contrast with `no_same_file_search_range` below. TODO confirm.
    let empty_search_range: Range<Point> = Default::default();

    // Case 1: local cursor at row 25. Row 3 is near the collaborator (row 5)
    // but far from the local cursor, so it must not be chosen.
    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, adjacent to the collaborator — row 3 is
    // now allowed because the local user is already working there.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50 — the diagnostic at row 50 is near the
    // cursor and far from the collaborator, so it should win.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Cross-file case: both other files get a diagnostic at row 0, but only
    // collab_file.txt has a collaborator caret on it.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // Covering the whole active file with the search range forces the lookup
    // to consider other files instead.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2576
/// Runs at binary load (via `#[ctor::ctor]`), before any test executes,
/// so log output is available for every test in this module.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}