1use super::*;
2use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9};
10use futures::{
11 AsyncReadExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::{
15 Entity, TestAppContext,
16 http_client::{FakeHttpClient, Response},
17};
18use indoc::indoc;
19use language::{Point, ToOffset as _};
20use lsp::LanguageServerId;
21use open_ai::Usage;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::{path, rel_path::rel_path};
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
/// Exercises how `prediction_at` reports the current prediction relative to the queried buffer.
///
/// 1. A prediction requested for `1.txt` is reported as `BufferEditPrediction::Local` when
///    queried from `1.txt`.
/// 2. After that prediction is rejected, a diagnostic is published for `2.txt`. The test then
///    expects another prediction request without an explicit refresh call. Its response edits
///    `2.txt`, so querying from `1.txt` reports a `BufferEditPrediction::Jump` to `root/2.txt`.
/// 3. Once `2.txt` is opened, the same prediction is reported as `Local` for that buffer.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open `1.txt` and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" (row 1, column 3).
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction targets the queried buffer itself, so it is reported as local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction before moving on to the cross-file case.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publish the diagnostic for `2.txt`. No refresh is requested explicitly below; the test
    // expects the next prediction request to follow from this update.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from `1.txt`, the prediction for `2.txt` is reported as a jump to that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from `2.txt` itself, the same prediction is local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
170#[gpui::test]
171async fn test_simple_request(cx: &mut TestAppContext) {
172 let (ep_store, mut requests) = init_test_with_fake_client(cx);
173 let fs = FakeFs::new(cx.executor());
174 fs.insert_tree(
175 "/root",
176 json!({
177 "foo.md": "Hello!\nHow\nBye\n"
178 }),
179 )
180 .await;
181 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
182
183 let buffer = project
184 .update(cx, |project, cx| {
185 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
186 project.open_buffer(path, cx)
187 })
188 .await
189 .unwrap();
190 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
191 let position = snapshot.anchor_before(language::Point::new(1, 3));
192
193 let prediction_task = ep_store.update(cx, |ep_store, cx| {
194 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
195 });
196
197 let (request, respond_tx) = requests.predict.next().await.unwrap();
198
199 // TODO Put back when we have a structured request again
200 // assert_eq!(
201 // request.excerpt_path.as_ref(),
202 // Path::new(path!("root/foo.md"))
203 // );
204 // assert_eq!(
205 // request.cursor_point,
206 // Point {
207 // line: Line(1),
208 // column: 3
209 // }
210 // );
211
212 respond_tx
213 .send(model_response(
214 request,
215 indoc! { r"
216 --- a/root/foo.md
217 +++ b/root/foo.md
218 @@ ... @@
219 Hello!
220 -How
221 +How are you?
222 Bye
223 "},
224 ))
225 .unwrap();
226
227 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
228
229 assert_eq!(prediction.edits.len(), 1);
230 assert_eq!(
231 prediction.edits[0].0.to_point(&snapshot).start,
232 language::Point::new(1, 3)
233 );
234 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
235}
236
237#[gpui::test]
238async fn test_request_events(cx: &mut TestAppContext) {
239 let (ep_store, mut requests) = init_test_with_fake_client(cx);
240 let fs = FakeFs::new(cx.executor());
241 fs.insert_tree(
242 "/root",
243 json!({
244 "foo.md": "Hello!\n\nBye\n"
245 }),
246 )
247 .await;
248 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
249
250 let buffer = project
251 .update(cx, |project, cx| {
252 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
253 project.open_buffer(path, cx)
254 })
255 .await
256 .unwrap();
257
258 ep_store.update(cx, |ep_store, cx| {
259 ep_store.register_buffer(&buffer, &project, cx);
260 });
261
262 buffer.update(cx, |buffer, cx| {
263 buffer.edit(vec![(7..7, "How")], None, cx);
264 });
265
266 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
267 let position = snapshot.anchor_before(language::Point::new(1, 3));
268
269 let prediction_task = ep_store.update(cx, |ep_store, cx| {
270 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
271 });
272
273 let (request, respond_tx) = requests.predict.next().await.unwrap();
274
275 let prompt = prompt_from_request(&request);
276 assert!(
277 prompt.contains(indoc! {"
278 --- a/root/foo.md
279 +++ b/root/foo.md
280 @@ -1,3 +1,3 @@
281 Hello!
282 -
283 +How
284 Bye
285 "}),
286 "{prompt}"
287 );
288
289 respond_tx
290 .send(model_response(
291 request,
292 indoc! {r#"
293 --- a/root/foo.md
294 +++ b/root/foo.md
295 @@ ... @@
296 Hello!
297 -How
298 +How are you?
299 Bye
300 "#},
301 ))
302 .unwrap();
303
304 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
305
306 assert_eq!(prediction.edits.len(), 1);
307 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
308}
309
/// Checks two edit-history getters against the same pair of edit bursts on one line, separated
/// by a pause:
/// - `edit_history_for_project` reports them as a single coalesced event;
/// - `edit_history_for_project_with_pause_split_last_event` splits that last event at the pause
///   into two diffs.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How" into the empty middle line (offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold by advancing the fake clock by twice
    // `LAST_CHANGE_GROUPING_TIME`.
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" (offset 10) on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    // Offset 19 is the end of "How are you?" (10 + 9).
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    // This `let` is irrefutable, so `Event` presumably has only the `BufferChange` variant.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events: the pre-pause insertion, then
    // the post-pause edits expressed relative to the pre-pause text.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
404
405#[gpui::test]
406async fn test_empty_prediction(cx: &mut TestAppContext) {
407 let (ep_store, mut requests) = init_test_with_fake_client(cx);
408 let fs = FakeFs::new(cx.executor());
409 fs.insert_tree(
410 "/root",
411 json!({
412 "foo.md": "Hello!\nHow\nBye\n"
413 }),
414 )
415 .await;
416 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
417
418 let buffer = project
419 .update(cx, |project, cx| {
420 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
421 project.open_buffer(path, cx)
422 })
423 .await
424 .unwrap();
425 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
426 let position = snapshot.anchor_before(language::Point::new(1, 3));
427
428 ep_store.update(cx, |ep_store, cx| {
429 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
430 });
431
432 let (request, respond_tx) = requests.predict.next().await.unwrap();
433 let response = model_response(request, "");
434 let id = response.id.clone();
435 respond_tx.send(response).unwrap();
436
437 cx.run_until_parked();
438
439 ep_store.update(cx, |ep_store, cx| {
440 assert!(
441 ep_store
442 .prediction_at(&buffer, None, &project, cx)
443 .is_none()
444 );
445 });
446
447 // prediction is reported as rejected
448 let (reject_request, _) = requests.reject.next().await.unwrap();
449
450 assert_eq!(
451 &reject_request.rejections,
452 &[EditPredictionRejection {
453 request_id: id,
454 reason: EditPredictionRejectReason::Empty,
455 was_shown: false
456 }]
457 );
458}
459
460#[gpui::test]
461async fn test_interpolated_empty(cx: &mut TestAppContext) {
462 let (ep_store, mut requests) = init_test_with_fake_client(cx);
463 let fs = FakeFs::new(cx.executor());
464 fs.insert_tree(
465 "/root",
466 json!({
467 "foo.md": "Hello!\nHow\nBye\n"
468 }),
469 )
470 .await;
471 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
472
473 let buffer = project
474 .update(cx, |project, cx| {
475 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
476 project.open_buffer(path, cx)
477 })
478 .await
479 .unwrap();
480 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
481 let position = snapshot.anchor_before(language::Point::new(1, 3));
482
483 ep_store.update(cx, |ep_store, cx| {
484 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
485 });
486
487 let (request, respond_tx) = requests.predict.next().await.unwrap();
488
489 buffer.update(cx, |buffer, cx| {
490 buffer.set_text("Hello!\nHow are you?\nBye", cx);
491 });
492
493 let response = model_response(request, SIMPLE_DIFF);
494 let id = response.id.clone();
495 respond_tx.send(response).unwrap();
496
497 cx.run_until_parked();
498
499 ep_store.update(cx, |ep_store, cx| {
500 assert!(
501 ep_store
502 .prediction_at(&buffer, None, &project, cx)
503 .is_none()
504 );
505 });
506
507 // prediction is reported as rejected
508 let (reject_request, _) = requests.reject.next().await.unwrap();
509
510 assert_eq!(
511 &reject_request.rejections,
512 &[EditPredictionRejection {
513 request_id: id,
514 reason: EditPredictionRejectReason::InterpolatedEmpty,
515 was_shown: false
516 }]
517 );
518}
519
/// Unified diff shared by several tests: against `root/foo.md` containing
/// `"Hello!\nHow\nBye\n"`, it rewrites the middle line `How` to `How are you?`.
///
/// Context lines (` Hello!`, ` Bye`) carry the single leading space required by the
/// unified-diff format.
const SIMPLE_DIFF: &str = concat!(
    "--- a/root/foo.md\n",
    "+++ b/root/foo.md\n",
    "@@ ... @@\n",
    " Hello!\n",
    "-How\n",
    "+How are you?\n",
    " Bye\n",
);
529
/// When a second response arrives for the same buffer, it becomes the current prediction. The
/// first prediction is then reported to the reject endpoint with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // First request: its response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
611
/// The second response here proposes less text (`How are`) than the current prediction
/// (`How are you?`). The current prediction stays, and the second response is reported as
/// rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // First request: its response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
705
/// Two refreshes are issued back to back and their responses arrive in reverse order.
/// - The second (newer) response becomes the current prediction.
/// - The later-arriving first response is ignored and reported with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Request of the second refresh.
    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
796
/// Checks what happens when a third refresh starts while two requests are still pending.
/// - The second pending request is marked cancelled.
/// - The first response becomes current; the cancelled second response is ignored.
/// - The third response then replaces the first.
/// - The final rejection batch reports the second as `Canceled` and the first as `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled. The recorded value `1` is presumably the
        // identifier of the second pending prediction (TODO confirm against the store).
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both rejections arrive in one batch: the second as canceled, then the first as
    // replaced by the third.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
932
933#[gpui::test]
934async fn test_rejections_flushing(cx: &mut TestAppContext) {
935 let (ep_store, mut requests) = init_test_with_fake_client(cx);
936
937 ep_store.update(cx, |ep_store, _cx| {
938 ep_store.reject_prediction(
939 EditPredictionId("test-1".into()),
940 EditPredictionRejectReason::Discarded,
941 false,
942 );
943 ep_store.reject_prediction(
944 EditPredictionId("test-2".into()),
945 EditPredictionRejectReason::Canceled,
946 true,
947 );
948 });
949
950 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
951 cx.run_until_parked();
952
953 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
954 respond_tx.send(()).unwrap();
955
956 // batched
957 assert_eq!(reject_request.rejections.len(), 2);
958 assert_eq!(
959 reject_request.rejections[0],
960 EditPredictionRejection {
961 request_id: "test-1".to_string(),
962 reason: EditPredictionRejectReason::Discarded,
963 was_shown: false
964 }
965 );
966 assert_eq!(
967 reject_request.rejections[1],
968 EditPredictionRejection {
969 request_id: "test-2".to_string(),
970 reason: EditPredictionRejectReason::Canceled,
971 was_shown: true
972 }
973 );
974
975 // Reaching batch size limit sends without debounce
976 ep_store.update(cx, |ep_store, _cx| {
977 for i in 0..70 {
978 ep_store.reject_prediction(
979 EditPredictionId(format!("batch-{}", i).into()),
980 EditPredictionRejectReason::Discarded,
981 false,
982 );
983 }
984 });
985
986 // First MAX/2 items are sent immediately
987 cx.run_until_parked();
988 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
989 respond_tx.send(()).unwrap();
990
991 assert_eq!(reject_request.rejections.len(), 50);
992 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
993 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
994
995 // Remaining items are debounced with the next batch
996 cx.executor().advance_clock(Duration::from_secs(15));
997 cx.run_until_parked();
998
999 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1000 respond_tx.send(()).unwrap();
1001
1002 assert_eq!(reject_request.rejections.len(), 20);
1003 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1004 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1005
1006 // Request failure
1007 ep_store.update(cx, |ep_store, _cx| {
1008 ep_store.reject_prediction(
1009 EditPredictionId("retry-1".into()),
1010 EditPredictionRejectReason::Discarded,
1011 false,
1012 );
1013 });
1014
1015 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1016 cx.run_until_parked();
1017
1018 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1019 assert_eq!(reject_request.rejections.len(), 1);
1020 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1021 // Simulate failure
1022 drop(_respond_tx);
1023
1024 // Add another rejection
1025 ep_store.update(cx, |ep_store, _cx| {
1026 ep_store.reject_prediction(
1027 EditPredictionId("retry-2".into()),
1028 EditPredictionRejectReason::Discarded,
1029 false,
1030 );
1031 });
1032
1033 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1034 cx.run_until_parked();
1035
1036 // Retry should include both the failed item and the new one
1037 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1038 respond_tx.send(()).unwrap();
1039
1040 assert_eq!(reject_request.rejections.len(), 2);
1041 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1042 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1043}
1044
1045// Skipped until we start including diagnostics in prompt
1046// #[gpui::test]
1047// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1048// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1049// let fs = FakeFs::new(cx.executor());
1050// fs.insert_tree(
1051// "/root",
1052// json!({
1053// "foo.md": "Hello!\nBye"
1054// }),
1055// )
1056// .await;
1057// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1058
1059// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1060// let diagnostic = lsp::Diagnostic {
1061// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1062// severity: Some(lsp::DiagnosticSeverity::ERROR),
1063// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1064// ..Default::default()
1065// };
1066
1067// project.update(cx, |project, cx| {
1068// project.lsp_store().update(cx, |lsp_store, cx| {
1069// // Create some diagnostics
1070// lsp_store
1071// .update_diagnostics(
1072// LanguageServerId(0),
1073// lsp::PublishDiagnosticsParams {
1074// uri: path_to_buffer_uri.clone(),
1075// diagnostics: vec![diagnostic],
1076// version: None,
1077// },
1078// None,
1079// language::DiagnosticSourceKind::Pushed,
1080// &[],
1081// cx,
1082// )
1083// .unwrap();
1084// });
1085// });
1086
1087// let buffer = project
1088// .update(cx, |project, cx| {
1089// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1090// project.open_buffer(path, cx)
1091// })
1092// .await
1093// .unwrap();
1094
1095// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1096// let position = snapshot.anchor_before(language::Point::new(0, 0));
1097
1098// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1099// ep_store.request_prediction(&project, &buffer, position, cx)
1100// });
1101
1102// let (request, _respond_tx) = req_rx.next().await.unwrap();
1103
1104// assert_eq!(request.diagnostic_groups.len(), 1);
1105// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1106// .unwrap();
1107// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1108// assert_eq!(
1109// value,
1110// json!({
1111// "entries": [{
1112// "range": {
1113// "start": 8,
1114// "end": 10
1115// },
1116// "diagnostic": {
1117// "source": null,
1118// "code": null,
1119// "code_description": null,
1120// "severity": 1,
1121// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1122// "markdown": null,
1123// "group_id": 0,
1124// "is_primary": true,
1125// "is_disk_based": false,
1126// "is_unnecessary": false,
1127// "source_kind": "Pushed",
1128// "data": null,
1129// "underline": true
1130// }
1131// }],
1132// "primary_ix": 0
1133// })
1134// );
1135// }
1136
1137// Generate a model response that would apply the given diff to the active file.
1138fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1139 let prompt = match &request.messages[0] {
1140 open_ai::RequestMessage::User {
1141 content: open_ai::MessageContent::Plain(content),
1142 } => content,
1143 _ => panic!("unexpected request {request:?}"),
1144 };
1145
1146 let open = "<editable_region>\n";
1147 let close = "</editable_region>";
1148 let cursor = "<|user_cursor|>";
1149
1150 let start_ix = open.len() + prompt.find(open).unwrap();
1151 let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1152 let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1153 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1154
1155 open_ai::Response {
1156 id: Uuid::new_v4().to_string(),
1157 object: "response".into(),
1158 created: 0,
1159 model: "model".into(),
1160 choices: vec![open_ai::Choice {
1161 index: 0,
1162 message: open_ai::RequestMessage::Assistant {
1163 content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1164 tool_calls: vec![],
1165 },
1166 finish_reason: None,
1167 }],
1168 usage: Usage {
1169 prompt_tokens: 0,
1170 completion_tokens: 0,
1171 total_tokens: 0,
1172 },
1173 }
1174}
1175
1176fn prompt_from_request(request: &open_ai::Request) -> &str {
1177 assert_eq!(request.messages.len(), 1);
1178 let open_ai::RequestMessage::User {
1179 content: open_ai::MessageContent::Plain(content),
1180 ..
1181 } = &request.messages[0]
1182 else {
1183 panic!(
1184 "Request does not have single user message of type Plain. {:#?}",
1185 request
1186 );
1187 };
1188 content
1189}
1190
/// Receiving ends for the HTTP requests intercepted by the fake client that
/// `init_test_with_fake_client` installs. Each request arrives paired with a
/// oneshot sender; the test sends the value that becomes the HTTP response body.
struct RequestChannels {
    /// Deserialized bodies of requests sent to `/predict_edits/raw`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    /// Deserialized bodies of requests sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1195
1196fn init_test_with_fake_client(
1197 cx: &mut TestAppContext,
1198) -> (Entity<EditPredictionStore>, RequestChannels) {
1199 cx.update(move |cx| {
1200 let settings_store = SettingsStore::test(cx);
1201 cx.set_global(settings_store);
1202 zlog::init_test();
1203
1204 let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1205 let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1206
1207 let http_client = FakeHttpClient::create({
1208 move |req| {
1209 let uri = req.uri().path().to_string();
1210 let mut body = req.into_body();
1211 let predict_req_tx = predict_req_tx.clone();
1212 let reject_req_tx = reject_req_tx.clone();
1213 async move {
1214 let resp = match uri.as_str() {
1215 "/client/llm_tokens" => serde_json::to_string(&json!({
1216 "token": "test"
1217 }))
1218 .unwrap(),
1219 "/predict_edits/raw" => {
1220 let mut buf = Vec::new();
1221 body.read_to_end(&mut buf).await.ok();
1222 let req = serde_json::from_slice(&buf).unwrap();
1223
1224 let (res_tx, res_rx) = oneshot::channel();
1225 predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1226 serde_json::to_string(&res_rx.await?).unwrap()
1227 }
1228 "/predict_edits/reject" => {
1229 let mut buf = Vec::new();
1230 body.read_to_end(&mut buf).await.ok();
1231 let req = serde_json::from_slice(&buf).unwrap();
1232
1233 let (res_tx, res_rx) = oneshot::channel();
1234 reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1235 serde_json::to_string(&res_rx.await?).unwrap()
1236 }
1237 _ => {
1238 panic!("Unexpected path: {}", uri)
1239 }
1240 };
1241
1242 Ok(Response::builder().body(resp.into()).unwrap())
1243 }
1244 }
1245 });
1246
1247 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1248 client.cloud_client().set_credentials(1, "test".into());
1249
1250 language_model::init(client.clone(), cx);
1251
1252 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1253 let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1254
1255 (
1256 ep_store,
1257 RequestChannels {
1258 predict: predict_req_rx,
1259 reject: reject_req_rx,
1260 },
1261 )
1262 })
1263}
1264
/// Text of the 0BSD license. The data-collection tests write it as a worktree's
/// `LICENSE` file to make that worktree count as openly licensed.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1266
1267#[gpui::test]
1268async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1269 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1270 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1271 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1272 });
1273
1274 let edit_preview = cx
1275 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1276 .await;
1277
1278 let prediction = EditPrediction {
1279 edits,
1280 edit_preview,
1281 buffer: buffer.clone(),
1282 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1283 id: EditPredictionId("the-id".into()),
1284 inputs: ZetaPromptInput {
1285 events: Default::default(),
1286 related_files: Default::default(),
1287 cursor_path: Path::new("").into(),
1288 cursor_excerpt: "".into(),
1289 editable_range_in_excerpt: 0..0,
1290 cursor_offset_in_excerpt: 0,
1291 },
1292 buffer_snapshotted_at: Instant::now(),
1293 response_received_at: Instant::now(),
1294 };
1295
1296 cx.update(|cx| {
1297 assert_eq!(
1298 from_completion_edits(
1299 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1300 &buffer,
1301 cx
1302 ),
1303 vec![(2..5, "REM".into()), (9..11, "".into())]
1304 );
1305
1306 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1307 assert_eq!(
1308 from_completion_edits(
1309 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1310 &buffer,
1311 cx
1312 ),
1313 vec![(2..2, "REM".into()), (6..8, "".into())]
1314 );
1315
1316 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1317 assert_eq!(
1318 from_completion_edits(
1319 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1320 &buffer,
1321 cx
1322 ),
1323 vec![(2..5, "REM".into()), (9..11, "".into())]
1324 );
1325
1326 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1327 assert_eq!(
1328 from_completion_edits(
1329 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1330 &buffer,
1331 cx
1332 ),
1333 vec![(3..3, "EM".into()), (7..9, "".into())]
1334 );
1335
1336 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1337 assert_eq!(
1338 from_completion_edits(
1339 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1340 &buffer,
1341 cx
1342 ),
1343 vec![(4..4, "M".into()), (8..10, "".into())]
1344 );
1345
1346 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1347 assert_eq!(
1348 from_completion_edits(
1349 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1350 &buffer,
1351 cx
1352 ),
1353 vec![(9..11, "".into())]
1354 );
1355
1356 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1357 assert_eq!(
1358 from_completion_edits(
1359 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1360 &buffer,
1361 cx
1362 ),
1363 vec![(4..4, "M".into()), (8..10, "".into())]
1364 );
1365
1366 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1367 assert_eq!(
1368 from_completion_edits(
1369 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1370 &buffer,
1371 cx
1372 ),
1373 vec![(4..4, "M".into())]
1374 );
1375
1376 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1377 assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1378 })
1379}
1380
1381#[gpui::test]
1382async fn test_clean_up_diff(cx: &mut TestAppContext) {
1383 init_test(cx);
1384
1385 assert_eq!(
1386 apply_edit_prediction(
1387 indoc! {"
1388 fn main() {
1389 let word_1 = \"lorem\";
1390 let range = word.len()..word.len();
1391 }
1392 "},
1393 indoc! {"
1394 <|editable_region_start|>
1395 fn main() {
1396 let word_1 = \"lorem\";
1397 let range = word_1.len()..word_1.len();
1398 }
1399
1400 <|editable_region_end|>
1401 "},
1402 cx,
1403 )
1404 .await,
1405 indoc! {"
1406 fn main() {
1407 let word_1 = \"lorem\";
1408 let range = word_1.len()..word_1.len();
1409 }
1410 "},
1411 );
1412
1413 assert_eq!(
1414 apply_edit_prediction(
1415 indoc! {"
1416 fn main() {
1417 let story = \"the quick\"
1418 }
1419 "},
1420 indoc! {"
1421 <|editable_region_start|>
1422 fn main() {
1423 let story = \"the quick brown fox jumps over the lazy dog\";
1424 }
1425
1426 <|editable_region_end|>
1427 "},
1428 cx,
1429 )
1430 .await,
1431 indoc! {"
1432 fn main() {
1433 let story = \"the quick brown fox jumps over the lazy dog\";
1434 }
1435 "},
1436 );
1437}
1438
1439#[gpui::test]
1440async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1441 init_test(cx);
1442
1443 let buffer_content = "lorem\n";
1444 let completion_response = indoc! {"
1445 ```animals.js
1446 <|start_of_file|>
1447 <|editable_region_start|>
1448 lorem
1449 ipsum
1450 <|editable_region_end|>
1451 ```"};
1452
1453 assert_eq!(
1454 apply_edit_prediction(buffer_content, completion_response, cx).await,
1455 "lorem\nipsum"
1456 );
1457}
1458
1459#[gpui::test]
1460async fn test_can_collect_data(cx: &mut TestAppContext) {
1461 init_test(cx);
1462
1463 let fs = project::FakeFs::new(cx.executor());
1464 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1465 .await;
1466
1467 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1468 let buffer = project
1469 .update(cx, |project, cx| {
1470 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1471 })
1472 .await
1473 .unwrap();
1474
1475 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1476 ep_store.update(cx, |ep_store, _cx| {
1477 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1478 });
1479
1480 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1481 assert_eq!(
1482 captured_request.lock().clone().unwrap().can_collect_data,
1483 true
1484 );
1485
1486 ep_store.update(cx, |ep_store, _cx| {
1487 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1488 });
1489
1490 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1491 assert_eq!(
1492 captured_request.lock().clone().unwrap().can_collect_data,
1493 false
1494 );
1495}
1496
1497#[gpui::test]
1498async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1499 init_test(cx);
1500
1501 let fs = project::FakeFs::new(cx.executor());
1502 let project = Project::test(fs.clone(), [], cx).await;
1503
1504 let buffer = cx.new(|_cx| {
1505 Buffer::remote(
1506 language::BufferId::new(1).unwrap(),
1507 ReplicaId::new(1),
1508 language::Capability::ReadWrite,
1509 "fn main() {\n println!(\"Hello\");\n}",
1510 )
1511 });
1512
1513 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1514 ep_store.update(cx, |ep_store, _cx| {
1515 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1516 });
1517
1518 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1519 assert_eq!(
1520 captured_request.lock().clone().unwrap().can_collect_data,
1521 false
1522 );
1523}
1524
1525#[gpui::test]
1526async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1527 init_test(cx);
1528
1529 let fs = project::FakeFs::new(cx.executor());
1530 fs.insert_tree(
1531 path!("/project"),
1532 json!({
1533 "LICENSE": BSD_0_TXT,
1534 ".env": "SECRET_KEY=secret"
1535 }),
1536 )
1537 .await;
1538
1539 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1540 let buffer = project
1541 .update(cx, |project, cx| {
1542 project.open_local_buffer("/project/.env", cx)
1543 })
1544 .await
1545 .unwrap();
1546
1547 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1548 ep_store.update(cx, |ep_store, _cx| {
1549 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1550 });
1551
1552 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1553 assert_eq!(
1554 captured_request.lock().clone().unwrap().can_collect_data,
1555 false
1556 );
1557}
1558
1559#[gpui::test]
1560async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1561 init_test(cx);
1562
1563 let fs = project::FakeFs::new(cx.executor());
1564 let project = Project::test(fs.clone(), [], cx).await;
1565 let buffer = cx.new(|cx| Buffer::local("", cx));
1566
1567 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1568 ep_store.update(cx, |ep_store, _cx| {
1569 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1570 });
1571
1572 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1573 assert_eq!(
1574 captured_request.lock().clone().unwrap().can_collect_data,
1575 false
1576 );
1577}
1578
1579#[gpui::test]
1580async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1581 init_test(cx);
1582
1583 let fs = project::FakeFs::new(cx.executor());
1584 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1585 .await;
1586
1587 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1588 let buffer = project
1589 .update(cx, |project, cx| {
1590 project.open_local_buffer("/project/main.rs", cx)
1591 })
1592 .await
1593 .unwrap();
1594
1595 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1596 ep_store.update(cx, |ep_store, _cx| {
1597 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1598 });
1599
1600 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1601 assert_eq!(
1602 captured_request.lock().clone().unwrap().can_collect_data,
1603 false
1604 );
1605}
1606
1607#[gpui::test]
1608async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1609 init_test(cx);
1610
1611 let fs = project::FakeFs::new(cx.executor());
1612 fs.insert_tree(
1613 path!("/open_source_worktree"),
1614 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1615 )
1616 .await;
1617 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1618 .await;
1619
1620 let project = Project::test(
1621 fs.clone(),
1622 [
1623 path!("/open_source_worktree").as_ref(),
1624 path!("/closed_source_worktree").as_ref(),
1625 ],
1626 cx,
1627 )
1628 .await;
1629 let buffer = project
1630 .update(cx, |project, cx| {
1631 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1632 })
1633 .await
1634 .unwrap();
1635
1636 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1637 ep_store.update(cx, |ep_store, _cx| {
1638 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1639 });
1640
1641 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1642 assert_eq!(
1643 captured_request.lock().clone().unwrap().can_collect_data,
1644 true
1645 );
1646
1647 let closed_source_file = project
1648 .update(cx, |project, cx| {
1649 let worktree2 = project
1650 .worktree_for_root_name("closed_source_worktree", cx)
1651 .unwrap();
1652 worktree2.update(cx, |worktree2, cx| {
1653 worktree2.load_file(rel_path("main.rs"), cx)
1654 })
1655 })
1656 .await
1657 .unwrap()
1658 .file;
1659
1660 buffer.update(cx, |buffer, cx| {
1661 buffer.file_updated(closed_source_file, cx);
1662 });
1663
1664 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1665 assert_eq!(
1666 captured_request.lock().clone().unwrap().can_collect_data,
1667 false
1668 );
1669}
1670
1671#[gpui::test]
1672async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1673 init_test(cx);
1674
1675 let fs = project::FakeFs::new(cx.executor());
1676 fs.insert_tree(
1677 path!("/worktree1"),
1678 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1679 )
1680 .await;
1681 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1682 .await;
1683
1684 let project = Project::test(
1685 fs.clone(),
1686 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1687 cx,
1688 )
1689 .await;
1690 let buffer = project
1691 .update(cx, |project, cx| {
1692 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1693 })
1694 .await
1695 .unwrap();
1696 let private_buffer = project
1697 .update(cx, |project, cx| {
1698 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1699 })
1700 .await
1701 .unwrap();
1702
1703 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1704 ep_store.update(cx, |ep_store, _cx| {
1705 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1706 });
1707
1708 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1709 assert_eq!(
1710 captured_request.lock().clone().unwrap().can_collect_data,
1711 true
1712 );
1713
1714 // this has a side effect of registering the buffer to watch for edits
1715 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1716 assert_eq!(
1717 captured_request.lock().clone().unwrap().can_collect_data,
1718 false
1719 );
1720
1721 private_buffer.update(cx, |private_buffer, cx| {
1722 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1723 });
1724
1725 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1726 assert_eq!(
1727 captured_request.lock().clone().unwrap().can_collect_data,
1728 false
1729 );
1730
1731 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1732 // included
1733 buffer.update(cx, |buffer, cx| {
1734 buffer.edit(
1735 [(
1736 0..0,
1737 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1738 )],
1739 None,
1740 cx,
1741 );
1742 });
1743
1744 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1745 assert_eq!(
1746 captured_request.lock().clone().unwrap().can_collect_data,
1747 true
1748 );
1749}
1750
1751fn init_test(cx: &mut TestAppContext) {
1752 cx.update(|cx| {
1753 let settings_store = SettingsStore::test(cx);
1754 cx.set_global(settings_store);
1755 });
1756}
1757
1758async fn apply_edit_prediction(
1759 buffer_content: &str,
1760 completion_response: &str,
1761 cx: &mut TestAppContext,
1762) -> String {
1763 let fs = project::FakeFs::new(cx.executor());
1764 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1765 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1766 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1767 *response.lock() = completion_response.to_string();
1768 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1769 buffer.update(cx, |buffer, cx| {
1770 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1771 });
1772 buffer.read_with(cx, |buffer, _| buffer.text())
1773}
1774
1775async fn run_edit_prediction(
1776 buffer: &Entity<Buffer>,
1777 project: &Entity<Project>,
1778 ep_store: &Entity<EditPredictionStore>,
1779 cx: &mut TestAppContext,
1780) -> EditPrediction {
1781 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1782 ep_store.update(cx, |ep_store, cx| {
1783 ep_store.register_buffer(buffer, &project, cx)
1784 });
1785 cx.background_executor.run_until_parked();
1786 let prediction_task = ep_store.update(cx, |ep_store, cx| {
1787 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1788 });
1789 prediction_task.await.unwrap().unwrap().prediction.unwrap()
1790}
1791
1792async fn make_test_ep_store(
1793 project: &Entity<Project>,
1794 cx: &mut TestAppContext,
1795) -> (
1796 Entity<EditPredictionStore>,
1797 Arc<Mutex<Option<PredictEditsBody>>>,
1798 Arc<Mutex<String>>,
1799) {
1800 let default_response = indoc! {"
1801 ```main.rs
1802 <|start_of_file|>
1803 <|editable_region_start|>
1804 hello world
1805 <|editable_region_end|>
1806 ```"
1807 };
1808 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1809 let completion_response: Arc<Mutex<String>> =
1810 Arc::new(Mutex::new(default_response.to_string()));
1811 let http_client = FakeHttpClient::create({
1812 let captured_request = captured_request.clone();
1813 let completion_response = completion_response.clone();
1814 let mut next_request_id = 0;
1815 move |req| {
1816 let captured_request = captured_request.clone();
1817 let completion_response = completion_response.clone();
1818 async move {
1819 match (req.method(), req.uri().path()) {
1820 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1821 .status(200)
1822 .body(
1823 serde_json::to_string(&CreateLlmTokenResponse {
1824 token: LlmToken("the-llm-token".to_string()),
1825 })
1826 .unwrap()
1827 .into(),
1828 )
1829 .unwrap()),
1830 (&Method::POST, "/predict_edits/v2") => {
1831 let mut request_body = String::new();
1832 req.into_body().read_to_string(&mut request_body).await?;
1833 *captured_request.lock() =
1834 Some(serde_json::from_str(&request_body).unwrap());
1835 next_request_id += 1;
1836 Ok(http_client::Response::builder()
1837 .status(200)
1838 .body(
1839 serde_json::to_string(&PredictEditsResponse {
1840 request_id: format!("request-{next_request_id}"),
1841 output_excerpt: completion_response.lock().clone(),
1842 })
1843 .unwrap()
1844 .into(),
1845 )
1846 .unwrap())
1847 }
1848 _ => Ok(http_client::Response::builder()
1849 .status(404)
1850 .body("Not Found".into())
1851 .unwrap()),
1852 }
1853 }
1854 }
1855 });
1856
1857 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1858 cx.update(|cx| {
1859 RefreshLlmTokenListener::register(client.clone(), cx);
1860 });
1861 let _server = FakeServer::for_client(42, &client, cx).await;
1862
1863 let ep_store = cx.new(|cx| {
1864 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
1865 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
1866
1867 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
1868 for worktree in worktrees {
1869 let worktree_id = worktree.read(cx).id();
1870 ep_store
1871 .get_or_init_project(project, cx)
1872 .license_detection_watchers
1873 .entry(worktree_id)
1874 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
1875 }
1876
1877 ep_store
1878 });
1879
1880 (ep_store, captured_request, completion_response)
1881}
1882
1883fn to_completion_edits(
1884 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1885 buffer: &Entity<Buffer>,
1886 cx: &App,
1887) -> Vec<(Range<Anchor>, Arc<str>)> {
1888 let buffer = buffer.read(cx);
1889 iterator
1890 .into_iter()
1891 .map(|(range, text)| {
1892 (
1893 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1894 text,
1895 )
1896 })
1897 .collect()
1898}
1899
1900fn from_completion_edits(
1901 editor_edits: &[(Range<Anchor>, Arc<str>)],
1902 buffer: &Entity<Buffer>,
1903 cx: &App,
1904) -> Vec<(Range<usize>, Arc<str>)> {
1905 let buffer = buffer.read(cx);
1906 editor_edits
1907 .iter()
1908 .map(|(range, text)| {
1909 (
1910 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1911 text.clone(),
1912 )
1913 })
1914 .collect()
1915}
1916
1917#[ctor::ctor]
1918fn init_logger() {
1919 zlog::init_test();
1920}