1use super::*;
2use crate::udiff::apply_diff_to_string;
3use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
4use clock::FakeSystemClock;
5use clock::ReplicaId;
6use cloud_api_types::{
7 CreateLlmTokenResponse, LlmToken, Organization, OrganizationConfiguration,
8 OrganizationEditPredictionConfiguration, OrganizationId,
9};
10use cloud_llm_client::{
11 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
12 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
13};
14use db::AppDatabase;
15use settings::EditPredictionDataCollectionChoice;
16
17use futures::{
18 AsyncReadExt, FutureExt, StreamExt,
19 channel::{mpsc, oneshot},
20};
21use gpui::App;
22use gpui::{
23 Entity, TestAppContext,
24 http_client::{FakeHttpClient, Response},
25};
26use indoc::indoc;
27use language::{
28 Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
29 DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
30};
31
32use lsp::LanguageServerId;
33use parking_lot::Mutex;
34use pretty_assertions::{assert_eq, assert_matches};
35use project::{FakeFs, Project};
36use serde_json::json;
37use settings::SettingsStore;
38use std::{ops::Range, path::Path, sync::Arc, time::Duration};
39use util::{
40 path,
41 test::{TextRangeMarker, marked_text_ranges_by},
42};
43use uuid::Uuid;
44use workspace::{AppState, CollaboratorId, MultiWorkspace};
45use zeta_prompt::ZetaPromptInput;
46
47use crate::{
48 BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
49 EditPredictionJumpsFeatureFlag, EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
50};
51
// End-to-end check of `prediction_at`: a prediction whose edits land in the
// buffer under the cursor is surfaced as `BufferEditPrediction::Local`, while
// a prediction targeting another file (requested here because a diagnostic was
// pushed for 2.txt) is surfaced as `Jump` — until that file's own buffer is
// queried, at which point it is `Local` again.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and mark it as the active path so predictions treat it as
    // the "current" file.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on line 1.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Fake model proposes an edit inside the active buffer (1.txt).
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // Edits in the cursor's buffer are reported as a local prediction.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Clear the current prediction before the cross-file scenario.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Pushing a diagnostic for 2.txt triggers a new prediction request.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from buffer1, the prediction targets a different file, so it is a
    // `Jump` pointing at 2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
187
// Verifies that diagnostic-driven prediction refreshes are suppressed while
// the workspace is following a collaborator (here the agent), and resume once
// the workspace unfollows: the first diagnostic push produces no predict
// request, the second (after unfollowing) does.
#[gpui::test]
async fn test_diagnostics_refresh_suppressed_while_following(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    // Enable the jumps feature flag so diagnostics in other files can trigger
    // predictions. (First argument is presumably the staff bit — see
    // `update_flags` — TODO confirm.)
    cx.update(|cx| {
        cx.update_flags(
            false,
            vec![EditPredictionJumpsFeatureFlag::NAME.to_string()],
        );
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let app_state = cx.update(|cx| {
        let app_state = AppState::test(cx);
        AppState::set_global(app_state.clone(), cx);
        app_state
    });

    // Build a real workspace so follow state can be toggled. Re-set the global
    // AppState to the workspace's own app_state so later lookups agree with it.
    let multi_workspace =
        cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
    let workspace = multi_workspace
        .read_with(cx, |multi_workspace, _| multi_workspace.workspace().clone())
        .unwrap();
    cx.update(|cx| {
        AppState::set_global(workspace.read(cx).app_state().clone(), cx);
    });
    let _ = app_state;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx);
    });

    // Satisfy the initial buffer-triggered request, then discard the result so
    // the store is idle before the diagnostics scenario starts.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Start following the agent.
    let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
        multi_workspace.workspace().update(cx, |workspace, cx| {
            workspace.start_following(CollaboratorId::Agent, window, cx);
        });
    });
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic.clone()],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // While following, the diagnostic must not trigger a predict request.
    cx.run_until_parked();
    assert_no_predict_request_ready(&mut requests.predict);

    // Stop following; diagnostics-driven refreshes should resume.
    let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
        multi_workspace.workspace().update(cx, |workspace, cx| {
            workspace.unfollow(CollaboratorId::Agent, window, cx);
        });
    });
    cx.run_until_parked();

    // Re-push the same diagnostic now that we are no longer following.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // This time a predict request arrives; answer it with an edit in 2.txt.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction is a jump into 2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });
}
353
// Happy-path request: `request_prediction` sends one predict request, and the
// model's diff is converted into buffer edits anchored at the correct point
// (the insertion " are you?" at the end of "How" on line 1).
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The "How" -> "How are you?" replacement is reduced to a single insertion
    // after the shared prefix, i.e. " are you?" at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
420
// Verifies that edits made before a prediction request are included in the
// request's prompt as an edit-history diff (here the insertion of "How" into
// the previously empty line 1).
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register the buffer so the store tracks its edits as history events.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Fill the empty line 1 with "How"; this edit should appear in the prompt.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt must contain the history diff for the pre-request edit.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
493
// Verifies time-based splitting of edit history: two bursts of edits on the
// same line, separated by a pause longer than `LAST_CHANGE_GROUPING_TIME`, are
// reported by `edit_history_for_project` as two distinct events with
// contiguous `total_edit_range`s and separate diffs.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);

    // Event 0 covers only the pre-pause insertion of "How" on line 1.
    let first_total_edit_range = buffer.read_with(cx, |buffer, _| {
        events[0].total_edit_range.to_point(&buffer.snapshot())
    });
    assert_eq!(first_total_edit_range, Point::new(1, 0)..Point::new(1, 3));

    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 1 covers both post-pause edits, merged into one diff.
    let second_total_edit_range = buffer.read_with(cx, |buffer, _| {
        events[1].total_edit_range.to_point(&buffer.snapshot())
    });
    assert_eq!(second_total_edit_range, Point::new(1, 3)..Point::new(1, 13));

    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
582
// Exercises line-distance-based coalescing of edit-history events: edits that
// land within 8 lines of an existing event's range (above or below) are merged
// into that event, while edits farther away start a new event.
// NOTE(review): the name mentions predicted edits, but this scenario only uses
// manual edits and distance-based coalescing — confirm intent.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the coalesced one, plus a fresh event for the far edit
    // (render_events joins them with "---").
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
761
762fn render_events(events: &[StoredEvent]) -> String {
763 events
764 .iter()
765 .map(|e| {
766 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
767 diff.as_str()
768 })
769 .collect::<Vec<_>>()
770 .join("\n---\n")
771}
772
773fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
774 events
775 .iter()
776 .map(|e| {
777 let zeta_prompt::Event::BufferChange {
778 diff, predicted, ..
779 } = e.event.as_ref();
780 let prefix = if *predicted { "predicted" } else { "manual" };
781 format!("{}\n{}", prefix, diff)
782 })
783 .collect()
784}
785
786fn make_collaborator_replica(
787 buffer: &Entity<Buffer>,
788 cx: &mut TestAppContext,
789) -> (Entity<Buffer>, clock::Global) {
790 let (state, version) =
791 buffer.read_with(cx, |buffer, _cx| (buffer.to_proto(_cx), buffer.version()));
792 let collaborator = cx.new(|_cx| {
793 Buffer::from_proto(ReplicaId::new(1), Capability::ReadWrite, state, None).unwrap()
794 });
795 (collaborator, version)
796}
797
798async fn apply_collaborator_edit(
799 collaborator: &Entity<Buffer>,
800 buffer: &Entity<Buffer>,
801 since_version: &mut clock::Global,
802 edit_range: Range<usize>,
803 new_text: &str,
804 cx: &mut TestAppContext,
805) {
806 collaborator.update(cx, |collaborator, cx| {
807 collaborator.edit([(edit_range, new_text)], None, cx);
808 });
809
810 let serialize_task = collaborator.read_with(cx, |collaborator, cx| {
811 collaborator.serialize_ops(Some(since_version.clone()), cx)
812 });
813 let ops = serialize_task.await;
814 *since_version = collaborator.read_with(cx, |collaborator, _cx| collaborator.version());
815
816 buffer.update(cx, |buffer, cx| {
817 buffer.apply_ops(
818 ops.into_iter()
819 .map(|op| language::proto::deserialize_operation(op).unwrap()),
820 cx,
821 );
822 });
823}
824
// Verifies that a collaborator's edit landing close to a local edit (adjacent
// line) is coalesced into the same edit-history event, and the merged event is
// labelled "manual".
#[gpui::test]
async fn test_nearby_collaborator_edits_are_kept_in_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));

    // Register the buffer and establish a cursor position via `prediction_at`
    // so remote-edit relevance can be judged against it.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
        let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
    });

    // Local edit on line 0.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
    });

    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);

    let (line_one_start, line_one_len) = collaborator.read_with(cx, |buffer, _cx| {
        (Point::new(1, 0).to_offset(buffer), buffer.line_len(1))
    });

    // Remote edit on line 1 — immediately next to the local edit.
    apply_collaborator_edit(
        &collaborator,
        &buffer,
        &mut collaborator_version,
        line_one_start..line_one_start + line_one_len as usize,
        "REMOTE ONE",
        cx,
    )
    .await;

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    // Both edits appear in a single coalesced "manual" event.
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LOCAL ZERO
            +REMOTE ONE
             line 2
             line 3
             line 4
        "}]
    );
}
893
// Verifies that a collaborator's edit far from the local cursor/edit (line
// 900 of a 1000-line file) is excluded from edit history: only the local edit
// remains.
#[gpui::test]
async fn test_distant_collaborator_edits_are_omitted_from_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": (0..1000)
                .map(|i| format!("line {i}\n"))
                .collect::<String>()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));

    // Establish the cursor near the top of the file.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
        let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
    });

    // Local edit on line 0, near the cursor.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
    });

    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);

    let far_line_start = buffer.read_with(cx, |buffer, _cx| Point::new(900, 0).to_offset(buffer));

    // Remote edit hundreds of lines away from the cursor.
    apply_collaborator_edit(
        &collaborator,
        &buffer,
        &mut collaborator_version,
        far_line_start..far_line_start + 7,
        "REMOTE FAR",
        cx,
    )
    .await;

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    // Only the local edit survives; the distant remote edit is dropped.
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LOCAL ZERO
             line 1
             line 2
             line 3
        "}]
    );
}
960
// Verifies that a collaborator edit in a different file than the one the user
// is working in (bar.rs vs the active foo.rs, with no local edits at all)
// contributes nothing to the edit history.
#[gpui::test]
async fn test_irrelevant_collaborator_edits_in_different_files_are_omitted_from_history(
    cx: &mut TestAppContext,
) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\n",
            "bar.rs": "line 0\nline 1\nline 2\nline 3\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // foo.rs is the active file; bar.rs is merely open.
    let foo_buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let bar_buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/bar.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let foo_cursor = foo_buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));

    // Track both buffers, with the cursor in foo.rs.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&foo_buffer, &project, cx);
        ep_store.register_buffer(&bar_buffer, &project, cx);
        let _ = ep_store.prediction_at(&foo_buffer, Some(foo_cursor), &project, cx);
    });

    let (bar_collaborator, mut bar_version) = make_collaborator_replica(&bar_buffer, cx);

    // Remote edit in bar.rs only.
    apply_collaborator_edit(
        &bar_collaborator,
        &bar_buffer,
        &mut bar_version,
        0..6,
        "REMOTE BAR",
        cx,
    )
    .await;

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    // No local edits and an irrelevant remote edit: history stays empty.
    assert!(events.is_empty());
}
1019
// Verifies that an edit whose diff would exceed `EDIT_HISTORY_DIFF_SIZE_LIMIT`
// is dropped from edit history entirely: the local edits before and after it
// remain (as two separate events), and neither contains the oversized text.
#[gpui::test]
async fn test_large_edits_are_omitted_from_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": (0..20)
                .map(|i| format!("line {i}\n"))
                .collect::<String>()
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
        let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
    });

    // First local edit, recorded normally.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
    });

    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);

    let (line_three_start, line_three_len) = collaborator.read_with(cx, |buffer, _cx| {
        (Point::new(3, 0).to_offset(buffer), buffer.line_len(3))
    });
    // One byte over the size limit, so its diff must be rejected.
    let large_edit = "X".repeat(EDIT_HISTORY_DIFF_SIZE_LIMIT + 1);

    // Nearby remote edit that is too large to keep in history.
    apply_collaborator_edit(
        &collaborator,
        &buffer,
        &mut collaborator_version,
        line_three_start..line_three_start + line_three_len as usize,
        &large_edit,
        cx,
    )
    .await;

    // Second local edit, after the oversized one.
    buffer.update(cx, |buffer, cx| {
        let line_seven_start = Point::new(7, 0).to_offset(buffer);
        let line_seven_end = Point::new(7, 6).to_offset(buffer);
        buffer.edit(
            vec![(line_seven_start..line_seven_end, "LOCAL SEVEN")],
            None,
            cx,
        );
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    let rendered_events = render_events_with_predicted(&events);

    // Two events (one per local edit), neither containing the large text.
    assert_eq!(rendered_events.len(), 2);
    assert!(rendered_events[0].contains("+LOCAL ZERO"));
    assert!(!rendered_events[0].contains(&large_edit));
    assert!(rendered_events[1].contains("+LOCAL SEVEN"));
    assert!(!rendered_events[1].contains(&large_edit));
}
1094
1095#[gpui::test]
1096async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
1097 let (ep_store, _requests) = init_test_with_fake_client(cx);
1098 let fs = FakeFs::new(cx.executor());
1099 fs.insert_tree(
1100 "/root",
1101 json!({
1102 "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
1103 }),
1104 )
1105 .await;
1106 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1107
1108 let buffer = project
1109 .update(cx, |project, cx| {
1110 let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
1111 project.open_buffer(path, cx)
1112 })
1113 .await
1114 .unwrap();
1115
1116 ep_store.update(cx, |ep_store, cx| {
1117 ep_store.register_buffer(&buffer, &project, cx);
1118 });
1119
1120 // Case 1: Manual edits have `predicted` set to false.
1121 buffer.update(cx, |buffer, cx| {
1122 buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
1123 });
1124
1125 let events = ep_store.update(cx, |ep_store, cx| {
1126 ep_store.edit_history_for_project(&project, cx)
1127 });
1128
1129 assert_eq!(
1130 render_events_with_predicted(&events),
1131 vec![indoc! {"
1132 manual
1133 @@ -1,4 +1,4 @@
1134 -line 0
1135 +LINE ZERO
1136 line 1
1137 line 2
1138 line 3
1139 "}]
1140 );
1141
1142 // Case 2: Multiple successive manual edits near each other are merged into one
1143 // event with `predicted` set to false.
1144 buffer.update(cx, |buffer, cx| {
1145 let offset = Point::new(1, 0).to_offset(buffer);
1146 let end = Point::new(1, 6).to_offset(buffer);
1147 buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
1148 });
1149
1150 let events = ep_store.update(cx, |ep_store, cx| {
1151 ep_store.edit_history_for_project(&project, cx)
1152 });
1153 assert_eq!(
1154 render_events_with_predicted(&events),
1155 vec![indoc! {"
1156 manual
1157 @@ -1,5 +1,5 @@
1158 -line 0
1159 -line 1
1160 +LINE ZERO
1161 +LINE ONE
1162 line 2
1163 line 3
1164 line 4
1165 "}]
1166 );
1167
1168 // Case 3: Accepted predictions have `predicted` set to true.
1169 // Case 5: A manual edit that follows a predicted edit is not merged with the
1170 // predicted edit, even if it is nearby.
1171 ep_store.update(cx, |ep_store, cx| {
1172 buffer.update(cx, |buffer, cx| {
1173 let offset = Point::new(2, 0).to_offset(buffer);
1174 let end = Point::new(2, 6).to_offset(buffer);
1175 buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
1176 });
1177 ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
1178 });
1179
1180 let events = ep_store.update(cx, |ep_store, cx| {
1181 ep_store.edit_history_for_project(&project, cx)
1182 });
1183 assert_eq!(
1184 render_events_with_predicted(&events),
1185 vec![
1186 indoc! {"
1187 manual
1188 @@ -1,5 +1,5 @@
1189 -line 0
1190 -line 1
1191 +LINE ZERO
1192 +LINE ONE
1193 line 2
1194 line 3
1195 line 4
1196 "},
1197 indoc! {"
1198 predicted
1199 @@ -1,6 +1,6 @@
1200 LINE ZERO
1201 LINE ONE
1202 -line 2
1203 +LINE TWO
1204 line 3
1205 line 4
1206 line 5
1207 "}
1208 ]
1209 );
1210
1211 // Case 4: Multiple successive accepted predictions near each other are merged
1212 // into one event with `predicted` set to true.
1213 ep_store.update(cx, |ep_store, cx| {
1214 buffer.update(cx, |buffer, cx| {
1215 let offset = Point::new(3, 0).to_offset(buffer);
1216 let end = Point::new(3, 6).to_offset(buffer);
1217 buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
1218 });
1219 ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
1220 });
1221
1222 let events = ep_store.update(cx, |ep_store, cx| {
1223 ep_store.edit_history_for_project(&project, cx)
1224 });
1225 assert_eq!(
1226 render_events_with_predicted(&events),
1227 vec![
1228 indoc! {"
1229 manual
1230 @@ -1,5 +1,5 @@
1231 -line 0
1232 -line 1
1233 +LINE ZERO
1234 +LINE ONE
1235 line 2
1236 line 3
1237 line 4
1238 "},
1239 indoc! {"
1240 predicted
1241 @@ -1,7 +1,7 @@
1242 LINE ZERO
1243 LINE ONE
1244 -line 2
1245 -line 3
1246 +LINE TWO
1247 +LINE THREE
1248 line 4
1249 line 5
1250 line 6
1251 "}
1252 ]
1253 );
1254
1255 // Case 5 (continued): A manual edit that follows a predicted edit is not merged
1256 // with the predicted edit, even if it is nearby.
1257 buffer.update(cx, |buffer, cx| {
1258 let offset = Point::new(4, 0).to_offset(buffer);
1259 let end = Point::new(4, 6).to_offset(buffer);
1260 buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
1261 });
1262
1263 let events = ep_store.update(cx, |ep_store, cx| {
1264 ep_store.edit_history_for_project(&project, cx)
1265 });
1266 assert_eq!(
1267 render_events_with_predicted(&events),
1268 vec![
1269 indoc! {"
1270 manual
1271 @@ -1,5 +1,5 @@
1272 -line 0
1273 -line 1
1274 +LINE ZERO
1275 +LINE ONE
1276 line 2
1277 line 3
1278 line 4
1279 "},
1280 indoc! {"
1281 predicted
1282 @@ -1,7 +1,7 @@
1283 LINE ZERO
1284 LINE ONE
1285 -line 2
1286 -line 3
1287 +LINE TWO
1288 +LINE THREE
1289 line 4
1290 line 5
1291 line 6
1292 "},
1293 indoc! {"
1294 manual
1295 @@ -2,7 +2,7 @@
1296 LINE ONE
1297 LINE TWO
1298 LINE THREE
1299 -line 4
1300 +LINE FOUR
1301 line 5
1302 line 6
1303 line 7
1304 "}
1305 ]
1306 );
1307
1308 // Case 6: If we then perform a manual edit at a *different* location (more than
1309 // 8 lines away), then the edits at the prior location can be merged with each
1310 // other, even if some are predicted and some are not. `predicted` means all
1311 // constituent edits were predicted.
1312 buffer.update(cx, |buffer, cx| {
1313 let offset = Point::new(14, 0).to_offset(buffer);
1314 let end = Point::new(14, 7).to_offset(buffer);
1315 buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
1316 });
1317
1318 let events = ep_store.update(cx, |ep_store, cx| {
1319 ep_store.edit_history_for_project(&project, cx)
1320 });
1321 assert_eq!(
1322 render_events_with_predicted(&events),
1323 vec![
1324 indoc! {"
1325 manual
1326 @@ -1,8 +1,8 @@
1327 -line 0
1328 -line 1
1329 -line 2
1330 -line 3
1331 -line 4
1332 +LINE ZERO
1333 +LINE ONE
1334 +LINE TWO
1335 +LINE THREE
1336 +LINE FOUR
1337 line 5
1338 line 6
1339 line 7
1340 "},
1341 indoc! {"
1342 manual
1343 @@ -12,4 +12,4 @@
1344 line 11
1345 line 12
1346 line 13
1347 -line 14
1348 +LINE FOURTEEN
1349 "}
1350 ]
1351 );
1352}
1353
1354#[gpui::test]
1355async fn test_empty_prediction(cx: &mut TestAppContext) {
1356 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1357 let fs = FakeFs::new(cx.executor());
1358 fs.insert_tree(
1359 "/root",
1360 json!({
1361 "foo.md": "Hello!\nHow\nBye\n"
1362 }),
1363 )
1364 .await;
1365 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1366
1367 let buffer = project
1368 .update(cx, |project, cx| {
1369 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1370 project.open_buffer(path, cx)
1371 })
1372 .await
1373 .unwrap();
1374 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1375 let position = snapshot.anchor_before(language::Point::new(1, 3));
1376
1377 ep_store.update(cx, |ep_store, cx| {
1378 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1379 });
1380
1381 let (request, respond_tx) = requests.predict.next().await.unwrap();
1382 let mut response = model_response(&request, "");
1383 response.model_version = Some("zeta2:test-empty".to_string());
1384 let id = response.request_id.clone();
1385 respond_tx.send(response).unwrap();
1386
1387 cx.run_until_parked();
1388
1389 ep_store.update(cx, |ep_store, cx| {
1390 assert!(
1391 ep_store
1392 .prediction_at(&buffer, None, &project, cx)
1393 .is_none()
1394 );
1395 });
1396
1397 // prediction is reported as rejected
1398 let (reject_request, _) = requests.reject.next().await.unwrap();
1399
1400 assert_eq!(
1401 &reject_request.rejections,
1402 &[EditPredictionRejection {
1403 request_id: id,
1404 reason: EditPredictionRejectReason::Empty,
1405 was_shown: false,
1406 model_version: Some("zeta2:test-empty".to_string()),
1407 e2e_latency_ms: Some(0),
1408 }]
1409 );
1410}
1411
1412#[gpui::test]
1413async fn test_interpolated_empty(cx: &mut TestAppContext) {
1414 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1415 let fs = FakeFs::new(cx.executor());
1416 fs.insert_tree(
1417 "/root",
1418 json!({
1419 "foo.md": "Hello!\nHow\nBye\n"
1420 }),
1421 )
1422 .await;
1423 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1424
1425 let buffer = project
1426 .update(cx, |project, cx| {
1427 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1428 project.open_buffer(path, cx)
1429 })
1430 .await
1431 .unwrap();
1432 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1433 let position = snapshot.anchor_before(language::Point::new(1, 3));
1434
1435 ep_store.update(cx, |ep_store, cx| {
1436 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1437 });
1438
1439 let (request, respond_tx) = requests.predict.next().await.unwrap();
1440
1441 buffer.update(cx, |buffer, cx| {
1442 buffer.set_text("Hello!\nHow are you?\nBye", cx);
1443 });
1444
1445 let mut response = model_response(&request, SIMPLE_DIFF);
1446 response.model_version = Some("zeta2:test-interpolated-empty".to_string());
1447 let id = response.request_id.clone();
1448 respond_tx.send(response).unwrap();
1449
1450 cx.run_until_parked();
1451
1452 ep_store.update(cx, |ep_store, cx| {
1453 assert!(
1454 ep_store
1455 .prediction_at(&buffer, None, &project, cx)
1456 .is_none()
1457 );
1458 });
1459
1460 // prediction is reported as rejected
1461 let (reject_request, _) = requests.reject.next().await.unwrap();
1462
1463 assert_eq!(
1464 &reject_request.rejections,
1465 &[EditPredictionRejection {
1466 request_id: id,
1467 reason: EditPredictionRejectReason::InterpolatedEmpty,
1468 was_shown: false,
1469 model_version: Some("zeta2:test-interpolated-empty".to_string()),
1470 e2e_latency_ms: Some(0),
1471 }]
1472 );
1473}
1474
/// A minimal unified diff against `root/foo.md` ("Hello!\nHow\nBye\n") that
/// rewrites its middle line; shared by several tests as a canonical non-empty
/// model response.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
1484
1485#[gpui::test]
1486async fn test_replace_current(cx: &mut TestAppContext) {
1487 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1488 let fs = FakeFs::new(cx.executor());
1489 fs.insert_tree(
1490 "/root",
1491 json!({
1492 "foo.md": "Hello!\nHow\nBye\n"
1493 }),
1494 )
1495 .await;
1496 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1497
1498 let buffer = project
1499 .update(cx, |project, cx| {
1500 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1501 project.open_buffer(path, cx)
1502 })
1503 .await
1504 .unwrap();
1505 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1506 let position = snapshot.anchor_before(language::Point::new(1, 3));
1507
1508 ep_store.update(cx, |ep_store, cx| {
1509 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1510 });
1511
1512 let (request, respond_tx) = requests.predict.next().await.unwrap();
1513 let first_response = model_response(&request, SIMPLE_DIFF);
1514 let first_id = first_response.request_id.clone();
1515 respond_tx.send(first_response).unwrap();
1516
1517 cx.run_until_parked();
1518
1519 ep_store.update(cx, |ep_store, cx| {
1520 assert_eq!(
1521 ep_store
1522 .prediction_at(&buffer, None, &project, cx)
1523 .unwrap()
1524 .id
1525 .0,
1526 first_id
1527 );
1528 });
1529
1530 // a second request is triggered
1531 ep_store.update(cx, |ep_store, cx| {
1532 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1533 });
1534
1535 let (request, respond_tx) = requests.predict.next().await.unwrap();
1536 let second_response = model_response(&request, SIMPLE_DIFF);
1537 let second_id = second_response.request_id.clone();
1538 respond_tx.send(second_response).unwrap();
1539
1540 cx.run_until_parked();
1541
1542 ep_store.update(cx, |ep_store, cx| {
1543 // second replaces first
1544 assert_eq!(
1545 ep_store
1546 .prediction_at(&buffer, None, &project, cx)
1547 .unwrap()
1548 .id
1549 .0,
1550 second_id
1551 );
1552 });
1553
1554 // first is reported as replaced
1555 let (reject_request, _) = requests.reject.next().await.unwrap();
1556
1557 assert_eq!(
1558 &reject_request.rejections,
1559 &[EditPredictionRejection {
1560 request_id: first_id,
1561 reason: EditPredictionRejectReason::Replaced,
1562 was_shown: false,
1563 model_version: None,
1564 e2e_latency_ms: Some(0),
1565 }]
1566 );
1567}
1568
1569#[gpui::test]
1570async fn test_current_preferred(cx: &mut TestAppContext) {
1571 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1572 let fs = FakeFs::new(cx.executor());
1573 fs.insert_tree(
1574 "/root",
1575 json!({
1576 "foo.md": "Hello!\nHow\nBye\n"
1577 }),
1578 )
1579 .await;
1580 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1581
1582 let buffer = project
1583 .update(cx, |project, cx| {
1584 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1585 project.open_buffer(path, cx)
1586 })
1587 .await
1588 .unwrap();
1589 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1590 let position = snapshot.anchor_before(language::Point::new(1, 3));
1591
1592 ep_store.update(cx, |ep_store, cx| {
1593 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1594 });
1595
1596 let (request, respond_tx) = requests.predict.next().await.unwrap();
1597 let first_response = model_response(&request, SIMPLE_DIFF);
1598 let first_id = first_response.request_id.clone();
1599 respond_tx.send(first_response).unwrap();
1600
1601 cx.run_until_parked();
1602
1603 ep_store.update(cx, |ep_store, cx| {
1604 assert_eq!(
1605 ep_store
1606 .prediction_at(&buffer, None, &project, cx)
1607 .unwrap()
1608 .id
1609 .0,
1610 first_id
1611 );
1612 });
1613
1614 // a second request is triggered
1615 ep_store.update(cx, |ep_store, cx| {
1616 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1617 });
1618
1619 let (request, respond_tx) = requests.predict.next().await.unwrap();
1620 // worse than current prediction
1621 let mut second_response = model_response(
1622 &request,
1623 indoc! { r"
1624 --- a/root/foo.md
1625 +++ b/root/foo.md
1626 @@ ... @@
1627 Hello!
1628 -How
1629 +How are
1630 Bye
1631 "},
1632 );
1633 second_response.model_version = Some("zeta2:test-current-preferred".to_string());
1634 let second_id = second_response.request_id.clone();
1635 respond_tx.send(second_response).unwrap();
1636
1637 cx.run_until_parked();
1638
1639 ep_store.update(cx, |ep_store, cx| {
1640 // first is preferred over second
1641 assert_eq!(
1642 ep_store
1643 .prediction_at(&buffer, None, &project, cx)
1644 .unwrap()
1645 .id
1646 .0,
1647 first_id
1648 );
1649 });
1650
1651 // second is reported as rejected
1652 let (reject_request, _) = requests.reject.next().await.unwrap();
1653
1654 assert_eq!(
1655 &reject_request.rejections,
1656 &[EditPredictionRejection {
1657 request_id: second_id,
1658 reason: EditPredictionRejectReason::CurrentPreferred,
1659 was_shown: false,
1660 model_version: Some("zeta2:test-current-preferred".to_string()),
1661 e2e_latency_ms: Some(0),
1662 }]
1663 );
1664}
1665
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    // Verifies that when two predict requests are in flight and the later one
    // resolves first, the earlier request is cancelled: its late response is
    // ignored and it is reported to the server with reject reason `Canceled`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Resolve the first (now-cancelled) request late. The distinct
    // model_version lets the rejection assertion below prove the report
    // refers to this specific response.
    let mut first_response = model_response(&request1, SIMPLE_DIFF);
    first_response.model_version = Some("zeta2:test-canceled".to_string());
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: Some("zeta2:test-canceled".to_string()),
            // no end-to-end latency is recorded for a cancelled prediction
            e2e_latency_ms: None,
        }]
    );
}
1759
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    // With two predict requests already pending, issuing a third cancels the
    // second (most recent pending) one. The first may still complete and be
    // shown, and is later replaced when the third completes.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Resolve the cancelled second request late; its response must not
    // displace the current prediction.
    let mut cancelled_response = model_response(&request2, SIMPLE_DIFF);
    cancelled_response.model_version = Some("zeta2:test-canceled-second".to_string());
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // both the cancelled second and the replaced first are reported as
    // rejected, batched into a single request
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: Some("zeta2:test-canceled-second".to_string()),
                e2e_latency_ms: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
                // 2 throttle waits (for 2nd and 3rd requests) elapsed
                // between this request's start and response.
                e2e_latency_ms: Some(2 * EditPredictionStore::THROTTLE_TIMEOUT.as_millis()),
            }
        ]
    );
}
1902
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    // Edit-triggered and jump-(diagnostic-)triggered predictions are
    // throttled separately: a recent edit request must not throttle the first
    // jump request and vice versa, while each kind still throttles itself.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // A diagnostic on bar.md (not the active buffer) used to trigger a jump
    // prediction below.
    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    // NOTE(review): the two consecutive run_until_parked calls below look
    // redundant (cx.run_until_parked should subsume the background one) —
    // confirm before removing.
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
2003
2004#[gpui::test]
2005async fn test_same_frame_duplicate_requests_deduplicated(cx: &mut TestAppContext) {
2006 let (ep_store, mut requests) = init_test_with_fake_client(cx);
2007 let fs = FakeFs::new(cx.executor());
2008 fs.insert_tree(
2009 "/root",
2010 json!({
2011 "foo.md": "Hello!\nHow\nBye\n"
2012 }),
2013 )
2014 .await;
2015 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2016
2017 let buffer = project
2018 .update(cx, |project, cx| {
2019 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2020 project.open_buffer(path, cx)
2021 })
2022 .await
2023 .unwrap();
2024 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2025 let position = snapshot.anchor_before(language::Point::new(1, 3));
2026
2027 // Enqueue two refresh calls in the same synchronous frame (no yielding).
2028 // Both `cx.spawn` tasks are created before either executes, so they both
2029 // capture the same `proceed_count_at_enqueue`. Only the first task should
2030 // pass the deduplication gate; the second should be skipped.
2031 ep_store.update(cx, |ep_store, cx| {
2032 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2033 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2034 });
2035
2036 // Let both spawned tasks run to completion (including any throttle waits).
2037 cx.run_until_parked();
2038
2039 // Exactly one prediction request should have been sent.
2040 let (request, respond_tx) = requests.predict.next().await.unwrap();
2041 respond_tx
2042 .send(model_response(&request, SIMPLE_DIFF))
2043 .unwrap();
2044 cx.run_until_parked();
2045
2046 // No second request should be pending.
2047 assert_no_predict_request_ready(&mut requests.predict);
2048}
2049
2050#[gpui::test]
2051async fn test_rejections_flushing(cx: &mut TestAppContext) {
2052 let (ep_store, mut requests) = init_test_with_fake_client(cx);
2053
2054 ep_store.update(cx, |ep_store, cx| {
2055 ep_store.reject_prediction(
2056 EditPredictionId("test-1".into()),
2057 EditPredictionRejectReason::Discarded,
2058 false,
2059 None,
2060 None,
2061 cx,
2062 );
2063 ep_store.reject_prediction(
2064 EditPredictionId("test-2".into()),
2065 EditPredictionRejectReason::Canceled,
2066 true,
2067 None,
2068 None,
2069 cx,
2070 );
2071 });
2072
2073 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
2074 cx.run_until_parked();
2075
2076 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2077 respond_tx.send(()).unwrap();
2078
2079 // batched
2080 assert_eq!(reject_request.rejections.len(), 2);
2081 assert_eq!(
2082 reject_request.rejections[0],
2083 EditPredictionRejection {
2084 request_id: "test-1".to_string(),
2085 reason: EditPredictionRejectReason::Discarded,
2086 was_shown: false,
2087 model_version: None,
2088 e2e_latency_ms: None
2089 }
2090 );
2091 assert_eq!(
2092 reject_request.rejections[1],
2093 EditPredictionRejection {
2094 request_id: "test-2".to_string(),
2095 reason: EditPredictionRejectReason::Canceled,
2096 was_shown: true,
2097 model_version: None,
2098 e2e_latency_ms: None
2099 }
2100 );
2101
2102 // Reaching batch size limit sends without debounce
2103 ep_store.update(cx, |ep_store, cx| {
2104 for i in 0..70 {
2105 ep_store.reject_prediction(
2106 EditPredictionId(format!("batch-{}", i).into()),
2107 EditPredictionRejectReason::Discarded,
2108 false,
2109 None,
2110 None,
2111 cx,
2112 );
2113 }
2114 });
2115
2116 // First MAX/2 items are sent immediately
2117 cx.run_until_parked();
2118 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2119 respond_tx.send(()).unwrap();
2120
2121 assert_eq!(reject_request.rejections.len(), 50);
2122 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
2123 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
2124
2125 // Remaining items are debounced with the next batch
2126 cx.executor().advance_clock(Duration::from_secs(15));
2127 cx.run_until_parked();
2128
2129 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2130 respond_tx.send(()).unwrap();
2131
2132 assert_eq!(reject_request.rejections.len(), 20);
2133 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
2134 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
2135
2136 // Request failure
2137 ep_store.update(cx, |ep_store, cx| {
2138 ep_store.reject_prediction(
2139 EditPredictionId("retry-1".into()),
2140 EditPredictionRejectReason::Discarded,
2141 false,
2142 None,
2143 None,
2144 cx,
2145 );
2146 });
2147
2148 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
2149 cx.run_until_parked();
2150
2151 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
2152 assert_eq!(reject_request.rejections.len(), 1);
2153 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
2154 // Simulate failure
2155 drop(_respond_tx);
2156
2157 // Add another rejection
2158 ep_store.update(cx, |ep_store, cx| {
2159 ep_store.reject_prediction(
2160 EditPredictionId("retry-2".into()),
2161 EditPredictionRejectReason::Discarded,
2162 false,
2163 None,
2164 None,
2165 cx,
2166 );
2167 });
2168
2169 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
2170 cx.run_until_parked();
2171
2172 // Retry should include both the failed item and the new one
2173 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2174 respond_tx.send(()).unwrap();
2175
2176 assert_eq!(reject_request.rejections.len(), 2);
2177 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
2178 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
2179}
2180
2181#[gpui::test]
2182fn test_active_buffer_diagnostics_fetching(cx: &mut TestAppContext) {
2183 let diagnostic_marker: TextRangeMarker = ('«', '»').into();
2184 let search_range_marker: TextRangeMarker = ('[', ']').into();
2185
2186 let (text, mut ranges) = marked_text_ranges_by(
2187 indoc! {r#"
2188 fn alpha() {
2189 let «first_value» = 1;
2190 }
2191
2192 [fn beta() {
2193 let «second_value» = 2;
2194 let third_value = second_value + missing_symbol;
2195 }ˇ]
2196
2197 fn gamma() {
2198 let «fourth_value» = missing_other_symbol;
2199 }
2200 "#},
2201 vec![diagnostic_marker.clone(), search_range_marker.clone()],
2202 );
2203
2204 let diagnostic_ranges = ranges.remove(&diagnostic_marker).unwrap_or_default();
2205 let search_ranges = ranges.remove(&search_range_marker).unwrap_or_default();
2206
2207 let buffer = cx.new(|cx| Buffer::local(&text, cx));
2208
2209 buffer.update(cx, |buffer, cx| {
2210 let snapshot = buffer.snapshot();
2211 let diagnostics = DiagnosticSet::new(
2212 diagnostic_ranges
2213 .iter()
2214 .enumerate()
2215 .map(|(index, range)| DiagnosticEntry {
2216 range: snapshot.offset_to_point_utf16(range.start)
2217 ..snapshot.offset_to_point_utf16(range.end),
2218 diagnostic: Diagnostic {
2219 severity: match index {
2220 0 => DiagnosticSeverity::WARNING,
2221 1 => DiagnosticSeverity::ERROR,
2222 _ => DiagnosticSeverity::HINT,
2223 },
2224 message: match index {
2225 0 => "first warning".to_string(),
2226 1 => "second error".to_string(),
2227 _ => "third hint".to_string(),
2228 },
2229 group_id: index + 1,
2230 is_primary: true,
2231 source_kind: language::DiagnosticSourceKind::Pushed,
2232 ..Diagnostic::default()
2233 },
2234 }),
2235 &snapshot,
2236 );
2237 buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
2238 });
2239
2240 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2241 let search_range = snapshot.offset_to_point(search_ranges[0].start)
2242 ..snapshot.offset_to_point(search_ranges[0].end);
2243
2244 let active_buffer_diagnostics = zeta::active_buffer_diagnostics(&snapshot, search_range, 100);
2245
2246 assert_eq!(
2247 active_buffer_diagnostics,
2248 vec![zeta_prompt::ActiveBufferDiagnostic {
2249 severity: Some(1),
2250 message: "second error".to_string(),
2251 snippet: text,
2252 snippet_buffer_row_range: 5..5,
2253 diagnostic_range_in_snippet: 61..73,
2254 }]
2255 );
2256
2257 let buffer = cx.new(|cx| {
2258 Buffer::local(
2259 indoc! {"
2260 one
2261 two
2262 three
2263 four
2264 five
2265 "},
2266 cx,
2267 )
2268 });
2269
2270 buffer.update(cx, |buffer, cx| {
2271 let snapshot = buffer.snapshot();
2272 let diagnostics = DiagnosticSet::new(
2273 vec![
2274 DiagnosticEntry {
2275 range: text::PointUtf16::new(0, 0)..text::PointUtf16::new(0, 3),
2276 diagnostic: Diagnostic {
2277 severity: DiagnosticSeverity::ERROR,
2278 message: "row zero".to_string(),
2279 group_id: 1,
2280 is_primary: true,
2281 source_kind: language::DiagnosticSourceKind::Pushed,
2282 ..Diagnostic::default()
2283 },
2284 },
2285 DiagnosticEntry {
2286 range: text::PointUtf16::new(2, 0)..text::PointUtf16::new(2, 5),
2287 diagnostic: Diagnostic {
2288 severity: DiagnosticSeverity::WARNING,
2289 message: "row two".to_string(),
2290 group_id: 2,
2291 is_primary: true,
2292 source_kind: language::DiagnosticSourceKind::Pushed,
2293 ..Diagnostic::default()
2294 },
2295 },
2296 DiagnosticEntry {
2297 range: text::PointUtf16::new(4, 0)..text::PointUtf16::new(4, 4),
2298 diagnostic: Diagnostic {
2299 severity: DiagnosticSeverity::INFORMATION,
2300 message: "row four".to_string(),
2301 group_id: 3,
2302 is_primary: true,
2303 source_kind: language::DiagnosticSourceKind::Pushed,
2304 ..Diagnostic::default()
2305 },
2306 },
2307 ],
2308 &snapshot,
2309 );
2310 buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
2311 });
2312
2313 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2314
2315 let active_buffer_diagnostics =
2316 zeta::active_buffer_diagnostics(&snapshot, Point::new(2, 0)..Point::new(4, 0), 100);
2317
2318 assert_eq!(
2319 active_buffer_diagnostics
2320 .iter()
2321 .map(|diagnostic| (
2322 diagnostic.severity,
2323 diagnostic.message.clone(),
2324 diagnostic.snippet.clone(),
2325 diagnostic.snippet_buffer_row_range.clone(),
2326 diagnostic.diagnostic_range_in_snippet.clone(),
2327 ))
2328 .collect::<Vec<_>>(),
2329 vec![
2330 (
2331 Some(2),
2332 "row two".to_string(),
2333 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2334 2..2,
2335 8..13,
2336 ),
2337 (
2338 Some(3),
2339 "row four".to_string(),
2340 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2341 4..4,
2342 19..23,
2343 ),
2344 ]
2345 );
2346}
2347
2348// Generate a model response that would apply the given diff to the active file.
2349fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
2350 let editable_range =
2351 zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
2352 let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
2353 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
2354
2355 PredictEditsV3Response {
2356 request_id: Uuid::new_v4().to_string(),
2357 editable_range,
2358 output: new_excerpt,
2359 model_version: None,
2360 }
2361}
2362
2363fn empty_response() -> PredictEditsV3Response {
2364 PredictEditsV3Response {
2365 request_id: Uuid::new_v4().to_string(),
2366 editable_range: 0..0,
2367 output: String::new(),
2368 model_version: None,
2369 }
2370}
2371
2372fn prompt_from_request(request: &PredictEditsV3Request) -> String {
2373 zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
2374 .expect("default zeta prompt formatting should succeed in edit prediction tests")
2375}
2376
2377fn assert_no_predict_request_ready(
2378 requests: &mut mpsc::UnboundedReceiver<(
2379 PredictEditsV3Request,
2380 oneshot::Sender<PredictEditsV3Response>,
2381 )>,
2382) {
2383 if requests.next().now_or_never().flatten().is_some() {
2384 panic!("Unexpected prediction request while throttled.");
2385 }
2386}
2387
/// Receivers for the HTTP requests captured by the fake client set up in
/// `init_test_with_fake_client*`. Each request arrives paired with a oneshot
/// sender the test uses to deliver the fake server's response.
struct RequestChannels {
    // Requests to `/predict_edits/v3`, paired with the response sender.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests to `/predict_edits/reject`; sending `()` acknowledges them.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
2395
2396fn init_test_with_fake_client(
2397 cx: &mut TestAppContext,
2398) -> (Entity<EditPredictionStore>, RequestChannels) {
2399 init_test_with_fake_client_and_legacy_data_collection(cx, None)
2400}
2401
/// Set up an `EditPredictionStore` backed by a fake HTTP client, optionally
/// seeding a legacy data-collection choice in the key-value store first.
/// Returns the store plus channels that surface every predict/reject request
/// so the test can inspect the request and reply through the paired sender.
fn init_test_with_fake_client_and_legacy_data_collection(
    cx: &mut TestAppContext,
    legacy_data_collection_choice: Option<&str>,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        // Globals the edit-prediction code paths read during the test.
        cx.set_global(AppDatabase::test_new());
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        // Optionally pre-populate the legacy KVP so migration paths can be
        // exercised. The write is expected to resolve synchronously in tests.
        if let Some(legacy_data_collection_choice) = legacy_data_collection_choice {
            KeyValueStore::global(cx)
                .write_kvp(
                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
                    legacy_data_collection_choice.to_string(),
                )
                .now_or_never()
                .expect("legacy data collection write should complete immediately")
                .expect("legacy data collection write should succeed");
        }

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        // Fake HTTP client: forwards predict/reject requests to the test via
        // the channels above; the handler blocks until the test replies
        // through the matching oneshot sender.
        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answer immediately with a fixed token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction bodies are zstd-compressed JSON.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection bodies are plain (uncompressed) JSON.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(cx);
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
2484
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    // Prediction over "Lorem ipsum dolor": replace offsets 2..5 with "REM"
    // and delete offsets 9..11. Each scenario below edits the buffer and
    // checks how `interpolate` re-expresses the remaining prediction.
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Minimal prediction: only the fields interpolation reads are meaningful;
    // the prompt inputs are placeholders.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            active_buffer_diagnostics: vec![],
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        },
        model_version: None,
    };

    cx.update(|cx| {
        // Unedited buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the 2..5 span collapses the replacement to an insertion
        // at offset 2, and shifts the deletion left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the prefix "R" of the replacement consumes it, leaving "EM".
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E" consumes another character, leaving "M".
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing "M" completes the replacement; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the just-typed "M" re-introduces it as a pending insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Manually performing the predicted deletion leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction makes interpolation fail.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
2604
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Case 1: a rename fix — the prediction updates `word` to `word_1` on the
    // usage line, and applying it should yield exactly the model's text.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Case 2: a mid-string extension plus an added semicolon — again the
    // applied prediction should match the model output verbatim.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
2656
2657#[gpui::test]
2658async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2659 init_test(cx);
2660
2661 let buffer_content = "lorem\n";
2662 let completion_response = "lorem\nipsum\n";
2663
2664 assert_eq!(
2665 apply_edit_prediction(buffer_content, completion_response, cx).await,
2666 "lorem\nipsum\n"
2667 );
2668}
2669
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Place the cursor at the end of "hello" and request a prediction.
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let excerpt_length = request.input.cursor_excerpt.len();
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
        editable_range: 0..excerpt_length,
        model_version: None,
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2737
/// The model sometimes echoes the `<|user_cursor|>` marker in its output;
/// the resulting edits must not contain it.
#[gpui::test]
async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor at the end of "hello"; request a prediction.
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with output containing the cursor marker token.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let excerpt_length = request.input.cursor_excerpt.len();
    respond_tx
        .send(PredictEditsV3Response {
            request_id: Uuid::new_v4().to_string(),
            output: "hello<|user_cursor|> world".to_string(),
            editable_range: 0..excerpt_length,
            model_version: None,
        })
        .unwrap();

    cx.run_until_parked();

    // The edit text must be " world" — the marker must have been stripped.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let snapshot = buffer.read(cx).snapshot();
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| (range.to_offset(&snapshot), text.clone()))
            .collect();

        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
2796
2797fn init_test(cx: &mut TestAppContext) {
2798 cx.update(|cx| {
2799 cx.set_global(AppDatabase::test_new());
2800 let settings_store = SettingsStore::test(cx);
2801 cx.set_global(settings_store);
2802 });
2803}
2804
2805async fn apply_edit_prediction(
2806 buffer_content: &str,
2807 completion_response: &str,
2808 cx: &mut TestAppContext,
2809) -> String {
2810 let fs = project::FakeFs::new(cx.executor());
2811 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2812 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2813 let (ep_store, response) = make_test_ep_store(&project, cx).await;
2814 *response.lock() = completion_response.to_string();
2815 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2816 buffer.update(cx, |buffer, cx| {
2817 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2818 });
2819 buffer.read_with(cx, |buffer, _| buffer.text())
2820}
2821
2822async fn run_edit_prediction(
2823 buffer: &Entity<Buffer>,
2824 project: &Entity<Project>,
2825 ep_store: &Entity<EditPredictionStore>,
2826 cx: &mut TestAppContext,
2827) -> EditPrediction {
2828 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2829 ep_store.update(cx, |ep_store, cx| {
2830 ep_store.register_buffer(buffer, &project, cx)
2831 });
2832 cx.background_executor.run_until_parked();
2833 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2834 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2835 });
2836 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2837}
2838
/// Build an `EditPredictionStore` whose fake HTTP client always answers
/// `/predict_edits/v3` with the current contents of the returned
/// `Arc<Mutex<String>>` — tests set that string to control the next
/// completion. Also installs license-detection watchers for the project's
/// worktrees.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        // Monotonic counter so every prediction gets a distinct request id.
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    // Token refresh: return a fixed token.
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction: decompress the zstd body, then respond with
                    // the currently-configured completion over the whole
                    // excerpt.
                    (Method::POST, "/predict_edits/v3") => {
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);

        // Pre-seed a license watcher per worktree so license detection does
        // not lazily initialize during the test.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2922
2923fn to_completion_edits(
2924 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2925 buffer: &Entity<Buffer>,
2926 cx: &App,
2927) -> Vec<(Range<Anchor>, Arc<str>)> {
2928 let buffer = buffer.read(cx);
2929 iterator
2930 .into_iter()
2931 .map(|(range, text)| {
2932 (
2933 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2934 text,
2935 )
2936 })
2937 .collect()
2938}
2939
2940fn from_completion_edits(
2941 editor_edits: &[(Range<Anchor>, Arc<str>)],
2942 buffer: &Entity<Buffer>,
2943 cx: &App,
2944) -> Vec<(Range<usize>, Arc<str>)> {
2945 let buffer = buffer.read(cx);
2946 editor_edits
2947 .iter()
2948 .map(|(range, text)| {
2949 (
2950 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2951 text.clone(),
2952 )
2953 })
2954 .collect()
2955}
2956
/// With a client whose every HTTP request is answered 401 and no custom
/// prediction URL configured, requesting a prediction must return an error.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request — including the LLM token fetch — is rejected.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Request a prediction inside the empty fn body.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
3015
/// Diagnostic-jump targets near a collaborator's cursor should be skipped
/// unless the local cursor is also nearby; cross-file jumps should skip
/// buffers that have a collaborator in them.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    /// Simulate a remote collaborator by applying an `UpdateSelections` op
    /// from a foreign replica, placing their cursor at the start of `row`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    /// Push one ERROR diagnostic per row in `rows` for the file at `uri_path`.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60-line active file, plus two small files for the cross-file scenario.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator at row 5; diagnostics at rows 3, 25, 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Scenario 1: local cursor at row 25 — row 3 (near the collaborator,
    // far from us) must be excluded.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Scenario 2: local cursor at row 4 — row 3 is near both cursors, so it
    // is allowed.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Scenario 3: local cursor at row 50 — the diagnostic at row 50 is near
    // us and far from the collaborator, so it is picked.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Scenario 4 (cross-file): both other files have a diagnostic, but
    // collab_file.txt has a collaborator cursor — the jump should land in
    // free_file.txt instead.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // Search range covering the whole active file, so no same-file diagnostic
    // is a candidate and the search must go cross-file.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
3231
3232#[gpui::test]
3233async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
3234 let (ep_store, _requests) = init_test_with_fake_client(cx);
3235 let fs = FakeFs::new(cx.executor());
3236
3237 // Buffer with two clearly separated regions:
3238 // Region A = lines 0-9 (offsets 0..50)
3239 // Region B = lines 20-29 (offsets 105..155)
3240 // A big gap in between so edits in one region never overlap the other.
3241 let mut content = String::new();
3242 for i in 0..30 {
3243 content.push_str(&format!("line {i:02}\n"));
3244 }
3245
3246 fs.insert_tree(
3247 "/root",
3248 json!({
3249 "foo.md": content.clone()
3250 }),
3251 )
3252 .await;
3253 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3254
3255 let buffer = project
3256 .update(cx, |project, cx| {
3257 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3258 project.open_buffer(path, cx)
3259 })
3260 .await
3261 .unwrap();
3262
3263 type SettledEventRecord = (EditPredictionId, String);
3264 let settled_events: Arc<Mutex<Vec<SettledEventRecord>>> = Arc::new(Mutex::new(Vec::new()));
3265
3266 ep_store.update(cx, |ep_store, cx| {
3267 ep_store.register_buffer(&buffer, &project, cx);
3268
3269 let settled_events = settled_events.clone();
3270 ep_store.settled_event_callback = Some(Box::new(move |id, text| {
3271 settled_events.lock().push((id, text));
3272 }));
3273 });
3274
3275 // --- Phase 1: edit in region A and enqueue prediction A ---
3276
3277 buffer.update(cx, |buffer, cx| {
3278 // Edit at the start of line 0.
3279 buffer.edit(vec![(0..0, "ADDED ")], None, cx);
3280 });
3281 cx.run_until_parked();
3282
3283 let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3284 let empty_edits: Arc<[(Range<Anchor>, Arc<str>)]> = Vec::new().into();
3285 let edit_preview_a = buffer
3286 .read_with(cx, |buffer, cx| {
3287 buffer.preview_edits(empty_edits.clone(), cx)
3288 })
3289 .await;
3290
3291 // Region A: first 10 lines of the buffer.
3292 let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
3293
3294 ep_store.update(cx, |ep_store, cx| {
3295 ep_store.enqueue_settled_prediction(
3296 EditPredictionId("prediction-a".into()),
3297 &project,
3298 &buffer,
3299 &snapshot_a,
3300 editable_region_a.clone(),
3301 &edit_preview_a,
3302 None,
3303 Duration::from_secs(0),
3304 cx,
3305 );
3306 });
3307
3308 // --- Phase 2: repeatedly edit in region A to keep it unsettled ---
3309
3310 // Let the worker process the channel message before we start advancing.
3311 cx.run_until_parked();
3312
3313 let mut region_a_edit_offset = 5;
3314 for _ in 0..3 {
3315 // Edit inside region A (not at the boundary) so `last_edit_at` is
3316 // updated before the worker's next wake.
3317 buffer.update(cx, |buffer, cx| {
3318 buffer.edit(
3319 vec![(region_a_edit_offset..region_a_edit_offset, "x")],
3320 None,
3321 cx,
3322 );
3323 });
3324 region_a_edit_offset += 1;
3325 cx.run_until_parked();
3326
3327 cx.executor()
3328 .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
3329 cx.run_until_parked();
3330 assert!(
3331 settled_events.lock().is_empty(),
3332 "no settled events should fire while region A is still being edited"
3333 );
3334 }
3335
3336 // Still nothing settled.
3337 assert!(settled_events.lock().is_empty());
3338
3339 // --- Phase 3: edit in distinct region B, enqueue prediction B ---
3340 // Advance a small amount so B's quiescence window starts later than A's,
3341 // but not so much that A settles (A's last edit was at the start of
3342 // iteration 3, and it needs a full Q to settle).
3343 cx.executor()
3344 .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
3345 cx.run_until_parked();
3346 assert!(settled_events.lock().is_empty());
3347
3348 let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3349 let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));
3350
3351 buffer.update(cx, |buffer, cx| {
3352 buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
3353 });
3354 cx.run_until_parked();
3355
3356 let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3357 let edit_preview_b = buffer
3358 .read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx))
3359 .await;
3360 let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
3361
3362 ep_store.update(cx, |ep_store, cx| {
3363 ep_store.enqueue_settled_prediction(
3364 EditPredictionId("prediction-b".into()),
3365 &project,
3366 &buffer,
3367 &snapshot_b2,
3368 editable_region_b.clone(),
3369 &edit_preview_b,
3370 None,
3371 Duration::from_secs(0),
3372 cx,
3373 );
3374 });
3375
3376 cx.run_until_parked();
3377 assert!(
3378 settled_events.lock().is_empty(),
3379 "neither prediction should have settled yet"
3380 );
3381
3382 // --- Phase 4: let enough time pass for region A to settle ---
3383 // A's last edit was at T_a (during the last loop iteration). The worker is
3384 // sleeping until T_a + Q. We advance just enough to reach that wake time
3385 // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
3386 // 3*Q/2). At that point A has been quiet for Q and settles, but B was
3387 // enqueued only Q/4 ago and stays pending.
3388 cx.executor()
3389 .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
3390 cx.run_until_parked();
3391
3392 {
3393 let events = settled_events.lock().clone();
3394 assert_eq!(
3395 events.len(),
3396 1,
3397 "prediction and capture_sample for A should have settled, got: {events:?}"
3398 );
3399 assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
3400 }
3401
3402 // --- Phase 5: let more time pass for region B to settle ---
3403 // B's last edit was Q/4 before A settled. The worker rescheduled to
3404 // B's last_edit_at + Q, which is 3Q/4 from now.
3405 cx.executor()
3406 .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
3407 cx.run_until_parked();
3408
3409 {
3410 let events = settled_events.lock().clone();
3411 assert_eq!(
3412 events.len(),
3413 2,
3414 "both prediction and capture_sample settled events should be emitted for each request, got: {events:?}"
3415 );
3416 assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
3417 }
3418}
3419
3420#[gpui::test]
3421async fn test_data_collection_disabled_by_default(cx: &mut TestAppContext) {
3422 let (ep_store, _channels) = init_test_with_fake_client(cx);
3423
3424 cx.update(|cx| {
3425 assert!(!ep_store.read(cx).is_data_collection_enabled(cx));
3426 });
3427}
3428
3429#[gpui::test]
3430async fn test_data_collection_enabled_via_legacy_kv_store(cx: &mut TestAppContext) {
3431 let (ep_store, _channels) =
3432 init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3433
3434 cx.update(|cx| {
3435 assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3436 });
3437}
3438
3439#[gpui::test]
3440async fn test_data_collection_default_uses_cached_legacy_value(cx: &mut TestAppContext) {
3441 let (ep_store, _channels) =
3442 init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3443
3444 cx.update(|cx| {
3445 assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3446 });
3447
3448 cx.update(|cx| KeyValueStore::global(cx))
3449 .delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3450 .await
3451 .unwrap();
3452
3453 cx.update(|cx| {
3454 assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3455 });
3456}
3457
3458#[gpui::test]
3459async fn test_data_collection_setting_overrides_kv_store(cx: &mut TestAppContext) {
3460 let (ep_store, _channels) =
3461 init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3462
3463 // An explicit false in settings.json wins over the KV store.
3464 cx.update_global::<SettingsStore, _>(|settings, cx| {
3465 settings.update_user_settings(cx, |content| {
3466 content
3467 .project
3468 .all_languages
3469 .edit_predictions
3470 .get_or_insert_default()
3471 .allow_data_collection = Some(EditPredictionDataCollectionChoice::No);
3472 });
3473 });
3474
3475 cx.update(|cx| {
3476 assert!(!ep_store.read(cx).is_data_collection_enabled(cx));
3477 });
3478}
3479
3480#[gpui::test]
3481async fn test_data_collection_enabled_via_setting(cx: &mut TestAppContext) {
3482 let (ep_store, _channels) = init_test_with_fake_client(cx);
3483
3484 cx.update_global::<SettingsStore, _>(|settings, cx| {
3485 settings.update_user_settings(cx, |content| {
3486 content
3487 .project
3488 .all_languages
3489 .edit_predictions
3490 .get_or_insert_default()
3491 .allow_data_collection = Some(EditPredictionDataCollectionChoice::Yes);
3492 });
3493 });
3494
3495 cx.update(|cx| {
3496 assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3497 });
3498}
3499
3500#[gpui::test]
3501async fn test_data_collection_always_enabled_for_staff(cx: &mut TestAppContext) {
3502 let (ep_store, _channels) = init_test_with_fake_client(cx);
3503
3504 cx.update(|cx| {
3505 cx.set_staff(true);
3506 assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3507 });
3508}
3509
3510#[gpui::test]
3511async fn test_data_collection_disabled_by_organization_configuration(cx: &mut TestAppContext) {
3512 let (ep_store, _channels) = init_test_with_fake_client(cx);
3513
3514 cx.update_global::<SettingsStore, _>(|settings, cx| {
3515 settings.update_user_settings(cx, |content| {
3516 content
3517 .project
3518 .all_languages
3519 .edit_predictions
3520 .get_or_insert_default()
3521 .allow_data_collection = Some(EditPredictionDataCollectionChoice::Yes);
3522 });
3523 });
3524
3525 let user_store = cx.update(|cx| ep_store.read(cx).user_store.clone());
3526 cx.update(|cx| {
3527 user_store.update(cx, |user_store, cx| {
3528 user_store.set_current_organization_configuration_for_test(
3529 Arc::new(Organization {
3530 id: OrganizationId("org-1".into()),
3531 name: "Org 1".into(),
3532 is_personal: false,
3533 }),
3534 OrganizationConfiguration {
3535 is_zed_model_provider_enabled: true,
3536 is_agent_thread_feedback_enabled: true,
3537 is_collaboration_enabled: true,
3538 edit_prediction: OrganizationEditPredictionConfiguration {
3539 is_enabled: true,
3540 is_feedback_enabled: false,
3541 },
3542 },
3543 cx,
3544 );
3545 });
3546
3547 assert!(!ep_store.read(cx).is_data_collection_enabled(cx));
3548 });
3549}
3550
3551// When a user had data collection enabled via the legacy KV store (with no explicit
3552// setting in settings.json), toggle_data_collection must read the *resolved* state
3553// (true) and write Some(false).
3554#[gpui::test]
3555async fn test_toggle_data_collection_from_kv_enabled_state(cx: &mut TestAppContext) {
3556 let (ep_store, _channels) =
3557 init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3558
3559 cx.update(|cx| {
3560 assert!(
3561 ep_store.read(cx).is_data_collection_enabled(cx),
3562 "data collection should be enabled via KV store before toggle"
3563 );
3564 });
3565
3566 // Simulate what toggle_data_collection does: capture the resolved current
3567 // state, then write its inverse.
3568 let is_currently_enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
3569 cx.update_global::<SettingsStore, _>(|settings, cx| {
3570 settings.update_user_settings(cx, |content| {
3571 content
3572 .project
3573 .all_languages
3574 .edit_predictions
3575 .get_or_insert_default()
3576 .allow_data_collection = Some(if is_currently_enabled {
3577 EditPredictionDataCollectionChoice::No
3578 } else {
3579 EditPredictionDataCollectionChoice::Yes
3580 });
3581 });
3582 });
3583
3584 cx.update(|cx| {
3585 assert!(
3586 !ep_store.read(cx).is_data_collection_enabled(cx),
3587 "data collection should be disabled after toggling off from KV-enabled state"
3588 );
3589 });
3590}
3591
3592#[gpui::test]
3593async fn test_upsell_shown_by_default(cx: &mut TestAppContext) {
3594 init_test(cx);
3595 let kvp = cx.update(|cx| KeyValueStore::global(cx));
3596 kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3597 .await
3598 .ok();
3599 kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.ok();
3600
3601 cx.update(|cx| assert!(should_show_upsell_modal(cx)));
3602}
3603
3604#[gpui::test]
3605async fn test_upsell_dismissed_when_data_collection_choice_in_kv_store(cx: &mut TestAppContext) {
3606 init_test(cx);
3607
3608 // Any value for the data collection key means the old upsell was already
3609 // shown, regardless of whether data collection was accepted or declined.
3610 for value in &["true", "false"] {
3611 cx.update(|cx| KeyValueStore::global(cx))
3612 .write_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), value.to_string())
3613 .await
3614 .unwrap();
3615
3616 cx.update(|cx| {
3617 assert!(
3618 !should_show_upsell_modal(cx),
3619 "upsell should be suppressed when data collection choice is '{value}'"
3620 );
3621 });
3622 }
3623
3624 cx.update(|cx| KeyValueStore::global(cx))
3625 .delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3626 .await
3627 .unwrap();
3628}
3629
3630#[gpui::test]
3631async fn test_upsell_dismissed_when_dismissed_key_set(cx: &mut TestAppContext) {
3632 init_test(cx);
3633 let kvp = cx.update(|cx| KeyValueStore::global(cx));
3634 kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3635 .await
3636 .ok();
3637 kvp.write_kvp(ZedPredictUpsell::KEY.into(), "1".into())
3638 .await
3639 .unwrap();
3640
3641 cx.update(|cx| assert!(!should_show_upsell_modal(cx)));
3642
3643 kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.unwrap();
3644}
3645
3646#[gpui::test]
3647async fn test_upsell_dismissed_via_dismissable_api(cx: &mut TestAppContext) {
3648 init_test(cx);
3649 let kvp = cx.update(|cx| KeyValueStore::global(cx));
3650 kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3651 .await
3652 .ok();
3653 kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.ok();
3654
3655 cx.update(|cx| {
3656 assert!(should_show_upsell_modal(cx));
3657 ZedPredictUpsell::set_dismissed(true, cx);
3658 });
3659 cx.run_until_parked();
3660
3661 cx.update(|cx| assert!(!should_show_upsell_modal(cx)));
3662
3663 kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.unwrap();
3664}
3665
// Runs once at binary startup (before any test executes) via the `ctor`
// crate, so every test in this module gets test logging without per-test
// setup.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}