1use super::*;
2use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9};
10use futures::{
11 AsyncReadExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::{
15 Entity, TestAppContext,
16 http_client::{FakeHttpClient, Response},
17};
18use indoc::indoc;
19use language::{Point, ToOffset as _};
20use lsp::LanguageServerId;
21use open_ai::Usage;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::{path, rel_path::rel_path};
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
34#[gpui::test]
35async fn test_current_state(cx: &mut TestAppContext) {
36 let (ep_store, mut requests) = init_test_with_fake_client(cx);
37 let fs = FakeFs::new(cx.executor());
38 fs.insert_tree(
39 "/root",
40 json!({
41 "1.txt": "Hello!\nHow\nBye\n",
42 "2.txt": "Hola!\nComo\nAdios\n"
43 }),
44 )
45 .await;
46 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
47
48 let buffer1 = project
49 .update(cx, |project, cx| {
50 let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
51 project.set_active_path(Some(path.clone()), cx);
52 project.open_buffer(path, cx)
53 })
54 .await
55 .unwrap();
56 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
57 let position = snapshot1.anchor_before(language::Point::new(1, 3));
58
59 ep_store.update(cx, |ep_store, cx| {
60 ep_store.register_project(&project, cx);
61 ep_store.register_buffer(&buffer1, &project, cx);
62 });
63
64 // Prediction for current file
65
66 ep_store.update(cx, |ep_store, cx| {
67 ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
68 });
69 let (request, respond_tx) = requests.predict.next().await.unwrap();
70
71 respond_tx
72 .send(model_response(
73 request,
74 indoc! {r"
75 --- a/root/1.txt
76 +++ b/root/1.txt
77 @@ ... @@
78 Hello!
79 -How
80 +How are you?
81 Bye
82 "},
83 ))
84 .unwrap();
85
86 cx.run_until_parked();
87
88 ep_store.update(cx, |ep_store, cx| {
89 let prediction = ep_store
90 .prediction_at(&buffer1, None, &project, cx)
91 .unwrap();
92 assert_matches!(prediction, BufferEditPrediction::Local { .. });
93 });
94
95 ep_store.update(cx, |ep_store, _cx| {
96 ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
97 });
98
99 // Prediction for diagnostic in another file
100
101 let diagnostic = lsp::Diagnostic {
102 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
103 severity: Some(lsp::DiagnosticSeverity::ERROR),
104 message: "Sentence is incomplete".to_string(),
105 ..Default::default()
106 };
107
108 project.update(cx, |project, cx| {
109 project.lsp_store().update(cx, |lsp_store, cx| {
110 lsp_store
111 .update_diagnostics(
112 LanguageServerId(0),
113 lsp::PublishDiagnosticsParams {
114 uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
115 diagnostics: vec![diagnostic],
116 version: None,
117 },
118 None,
119 language::DiagnosticSourceKind::Pushed,
120 &[],
121 cx,
122 )
123 .unwrap();
124 });
125 });
126
127 let (request, respond_tx) = requests.predict.next().await.unwrap();
128 respond_tx
129 .send(model_response(
130 request,
131 indoc! {r#"
132 --- a/root/2.txt
133 +++ b/root/2.txt
134 @@ ... @@
135 Hola!
136 -Como
137 +Como estas?
138 Adios
139 "#},
140 ))
141 .unwrap();
142 cx.run_until_parked();
143
144 ep_store.update(cx, |ep_store, cx| {
145 let prediction = ep_store
146 .prediction_at(&buffer1, None, &project, cx)
147 .unwrap();
148 assert_matches!(
149 prediction,
150 BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
151 );
152 });
153
154 let buffer2 = project
155 .update(cx, |project, cx| {
156 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
157 project.open_buffer(path, cx)
158 })
159 .await
160 .unwrap();
161
162 ep_store.update(cx, |ep_store, cx| {
163 let prediction = ep_store
164 .prediction_at(&buffer2, None, &project, cx)
165 .unwrap();
166 assert_matches!(prediction, BufferEditPrediction::Local { .. });
167 });
168}
169
170#[gpui::test]
171async fn test_simple_request(cx: &mut TestAppContext) {
172 let (ep_store, mut requests) = init_test_with_fake_client(cx);
173 let fs = FakeFs::new(cx.executor());
174 fs.insert_tree(
175 "/root",
176 json!({
177 "foo.md": "Hello!\nHow\nBye\n"
178 }),
179 )
180 .await;
181 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
182
183 let buffer = project
184 .update(cx, |project, cx| {
185 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
186 project.open_buffer(path, cx)
187 })
188 .await
189 .unwrap();
190 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
191 let position = snapshot.anchor_before(language::Point::new(1, 3));
192
193 let prediction_task = ep_store.update(cx, |ep_store, cx| {
194 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
195 });
196
197 let (request, respond_tx) = requests.predict.next().await.unwrap();
198
199 // TODO Put back when we have a structured request again
200 // assert_eq!(
201 // request.excerpt_path.as_ref(),
202 // Path::new(path!("root/foo.md"))
203 // );
204 // assert_eq!(
205 // request.cursor_point,
206 // Point {
207 // line: Line(1),
208 // column: 3
209 // }
210 // );
211
212 respond_tx
213 .send(model_response(
214 request,
215 indoc! { r"
216 --- a/root/foo.md
217 +++ b/root/foo.md
218 @@ ... @@
219 Hello!
220 -How
221 +How are you?
222 Bye
223 "},
224 ))
225 .unwrap();
226
227 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
228
229 assert_eq!(prediction.edits.len(), 1);
230 assert_eq!(
231 prediction.edits[0].0.to_point(&snapshot).start,
232 language::Point::new(1, 3)
233 );
234 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
235}
236
237#[gpui::test]
238async fn test_request_events(cx: &mut TestAppContext) {
239 let (ep_store, mut requests) = init_test_with_fake_client(cx);
240 let fs = FakeFs::new(cx.executor());
241 fs.insert_tree(
242 "/root",
243 json!({
244 "foo.md": "Hello!\n\nBye\n"
245 }),
246 )
247 .await;
248 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
249
250 let buffer = project
251 .update(cx, |project, cx| {
252 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
253 project.open_buffer(path, cx)
254 })
255 .await
256 .unwrap();
257
258 ep_store.update(cx, |ep_store, cx| {
259 ep_store.register_buffer(&buffer, &project, cx);
260 });
261
262 buffer.update(cx, |buffer, cx| {
263 buffer.edit(vec![(7..7, "How")], None, cx);
264 });
265
266 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
267 let position = snapshot.anchor_before(language::Point::new(1, 3));
268
269 let prediction_task = ep_store.update(cx, |ep_store, cx| {
270 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
271 });
272
273 let (request, respond_tx) = requests.predict.next().await.unwrap();
274
275 let prompt = prompt_from_request(&request);
276 assert!(
277 prompt.contains(indoc! {"
278 --- a/root/foo.md
279 +++ b/root/foo.md
280 @@ -1,3 +1,3 @@
281 Hello!
282 -
283 +How
284 Bye
285 "}),
286 "{prompt}"
287 );
288
289 respond_tx
290 .send(model_response(
291 request,
292 indoc! {r#"
293 --- a/root/foo.md
294 +++ b/root/foo.md
295 @@ ... @@
296 Hello!
297 -How
298 +How are you?
299 Bye
300 "#},
301 ))
302 .unwrap();
303
304 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
305
306 assert_eq!(prediction.edits.len(), 1);
307 assert_eq!(
308 prediction.edits[0].0.to_point(&snapshot).start,
309 language::Point::new(1, 3)
310 );
311 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
312}
313
314#[gpui::test]
315async fn test_empty_prediction(cx: &mut TestAppContext) {
316 let (ep_store, mut requests) = init_test_with_fake_client(cx);
317 let fs = FakeFs::new(cx.executor());
318 fs.insert_tree(
319 "/root",
320 json!({
321 "foo.md": "Hello!\nHow\nBye\n"
322 }),
323 )
324 .await;
325 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
326
327 let buffer = project
328 .update(cx, |project, cx| {
329 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
330 project.open_buffer(path, cx)
331 })
332 .await
333 .unwrap();
334 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
335 let position = snapshot.anchor_before(language::Point::new(1, 3));
336
337 ep_store.update(cx, |ep_store, cx| {
338 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
339 });
340
341 let (request, respond_tx) = requests.predict.next().await.unwrap();
342 let response = model_response(request, "");
343 let id = response.id.clone();
344 respond_tx.send(response).unwrap();
345
346 cx.run_until_parked();
347
348 ep_store.update(cx, |ep_store, cx| {
349 assert!(
350 ep_store
351 .prediction_at(&buffer, None, &project, cx)
352 .is_none()
353 );
354 });
355
356 // prediction is reported as rejected
357 let (reject_request, _) = requests.reject.next().await.unwrap();
358
359 assert_eq!(
360 &reject_request.rejections,
361 &[EditPredictionRejection {
362 request_id: id,
363 reason: EditPredictionRejectReason::Empty,
364 was_shown: false
365 }]
366 );
367}
368
369#[gpui::test]
370async fn test_interpolated_empty(cx: &mut TestAppContext) {
371 let (ep_store, mut requests) = init_test_with_fake_client(cx);
372 let fs = FakeFs::new(cx.executor());
373 fs.insert_tree(
374 "/root",
375 json!({
376 "foo.md": "Hello!\nHow\nBye\n"
377 }),
378 )
379 .await;
380 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
381
382 let buffer = project
383 .update(cx, |project, cx| {
384 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
385 project.open_buffer(path, cx)
386 })
387 .await
388 .unwrap();
389 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
390 let position = snapshot.anchor_before(language::Point::new(1, 3));
391
392 ep_store.update(cx, |ep_store, cx| {
393 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
394 });
395
396 let (request, respond_tx) = requests.predict.next().await.unwrap();
397
398 buffer.update(cx, |buffer, cx| {
399 buffer.set_text("Hello!\nHow are you?\nBye", cx);
400 });
401
402 let response = model_response(request, SIMPLE_DIFF);
403 let id = response.id.clone();
404 respond_tx.send(response).unwrap();
405
406 cx.run_until_parked();
407
408 ep_store.update(cx, |ep_store, cx| {
409 assert!(
410 ep_store
411 .prediction_at(&buffer, None, &project, cx)
412 .is_none()
413 );
414 });
415
416 // prediction is reported as rejected
417 let (reject_request, _) = requests.reject.next().await.unwrap();
418
419 assert_eq!(
420 &reject_request.rejections,
421 &[EditPredictionRejection {
422 request_id: id,
423 reason: EditPredictionRejectReason::InterpolatedEmpty,
424 was_shown: false
425 }]
426 );
427}
428
429const SIMPLE_DIFF: &str = indoc! { r"
430 --- a/root/foo.md
431 +++ b/root/foo.md
432 @@ ... @@
433 Hello!
434 -How
435 +How are you?
436 Bye
437"};
438
439#[gpui::test]
440async fn test_replace_current(cx: &mut TestAppContext) {
441 let (ep_store, mut requests) = init_test_with_fake_client(cx);
442 let fs = FakeFs::new(cx.executor());
443 fs.insert_tree(
444 "/root",
445 json!({
446 "foo.md": "Hello!\nHow\nBye\n"
447 }),
448 )
449 .await;
450 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
451
452 let buffer = project
453 .update(cx, |project, cx| {
454 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
455 project.open_buffer(path, cx)
456 })
457 .await
458 .unwrap();
459 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
460 let position = snapshot.anchor_before(language::Point::new(1, 3));
461
462 ep_store.update(cx, |ep_store, cx| {
463 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
464 });
465
466 let (request, respond_tx) = requests.predict.next().await.unwrap();
467 let first_response = model_response(request, SIMPLE_DIFF);
468 let first_id = first_response.id.clone();
469 respond_tx.send(first_response).unwrap();
470
471 cx.run_until_parked();
472
473 ep_store.update(cx, |ep_store, cx| {
474 assert_eq!(
475 ep_store
476 .prediction_at(&buffer, None, &project, cx)
477 .unwrap()
478 .id
479 .0,
480 first_id
481 );
482 });
483
484 // a second request is triggered
485 ep_store.update(cx, |ep_store, cx| {
486 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
487 });
488
489 let (request, respond_tx) = requests.predict.next().await.unwrap();
490 let second_response = model_response(request, SIMPLE_DIFF);
491 let second_id = second_response.id.clone();
492 respond_tx.send(second_response).unwrap();
493
494 cx.run_until_parked();
495
496 ep_store.update(cx, |ep_store, cx| {
497 // second replaces first
498 assert_eq!(
499 ep_store
500 .prediction_at(&buffer, None, &project, cx)
501 .unwrap()
502 .id
503 .0,
504 second_id
505 );
506 });
507
508 // first is reported as replaced
509 let (reject_request, _) = requests.reject.next().await.unwrap();
510
511 assert_eq!(
512 &reject_request.rejections,
513 &[EditPredictionRejection {
514 request_id: first_id,
515 reason: EditPredictionRejectReason::Replaced,
516 was_shown: false
517 }]
518 );
519}
520
521#[gpui::test]
522async fn test_current_preferred(cx: &mut TestAppContext) {
523 let (ep_store, mut requests) = init_test_with_fake_client(cx);
524 let fs = FakeFs::new(cx.executor());
525 fs.insert_tree(
526 "/root",
527 json!({
528 "foo.md": "Hello!\nHow\nBye\n"
529 }),
530 )
531 .await;
532 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
533
534 let buffer = project
535 .update(cx, |project, cx| {
536 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
537 project.open_buffer(path, cx)
538 })
539 .await
540 .unwrap();
541 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
542 let position = snapshot.anchor_before(language::Point::new(1, 3));
543
544 ep_store.update(cx, |ep_store, cx| {
545 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
546 });
547
548 let (request, respond_tx) = requests.predict.next().await.unwrap();
549 let first_response = model_response(request, SIMPLE_DIFF);
550 let first_id = first_response.id.clone();
551 respond_tx.send(first_response).unwrap();
552
553 cx.run_until_parked();
554
555 ep_store.update(cx, |ep_store, cx| {
556 assert_eq!(
557 ep_store
558 .prediction_at(&buffer, None, &project, cx)
559 .unwrap()
560 .id
561 .0,
562 first_id
563 );
564 });
565
566 // a second request is triggered
567 ep_store.update(cx, |ep_store, cx| {
568 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
569 });
570
571 let (request, respond_tx) = requests.predict.next().await.unwrap();
572 // worse than current prediction
573 let second_response = model_response(
574 request,
575 indoc! { r"
576 --- a/root/foo.md
577 +++ b/root/foo.md
578 @@ ... @@
579 Hello!
580 -How
581 +How are
582 Bye
583 "},
584 );
585 let second_id = second_response.id.clone();
586 respond_tx.send(second_response).unwrap();
587
588 cx.run_until_parked();
589
590 ep_store.update(cx, |ep_store, cx| {
591 // first is preferred over second
592 assert_eq!(
593 ep_store
594 .prediction_at(&buffer, None, &project, cx)
595 .unwrap()
596 .id
597 .0,
598 first_id
599 );
600 });
601
602 // second is reported as rejected
603 let (reject_request, _) = requests.reject.next().await.unwrap();
604
605 assert_eq!(
606 &reject_request.rejections,
607 &[EditPredictionRejection {
608 request_id: second_id,
609 reason: EditPredictionRejectReason::CurrentPreferred,
610 was_shown: false
611 }]
612 );
613}
614
615#[gpui::test]
616async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
617 let (ep_store, mut requests) = init_test_with_fake_client(cx);
618 let fs = FakeFs::new(cx.executor());
619 fs.insert_tree(
620 "/root",
621 json!({
622 "foo.md": "Hello!\nHow\nBye\n"
623 }),
624 )
625 .await;
626 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
627
628 let buffer = project
629 .update(cx, |project, cx| {
630 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
631 project.open_buffer(path, cx)
632 })
633 .await
634 .unwrap();
635 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
636 let position = snapshot.anchor_before(language::Point::new(1, 3));
637
638 // start two refresh tasks
639 ep_store.update(cx, |ep_store, cx| {
640 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
641 });
642
643 let (request1, respond_first) = requests.predict.next().await.unwrap();
644
645 ep_store.update(cx, |ep_store, cx| {
646 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
647 });
648
649 let (request, respond_second) = requests.predict.next().await.unwrap();
650
651 // wait for throttle
652 cx.run_until_parked();
653
654 // second responds first
655 let second_response = model_response(request, SIMPLE_DIFF);
656 let second_id = second_response.id.clone();
657 respond_second.send(second_response).unwrap();
658
659 cx.run_until_parked();
660
661 ep_store.update(cx, |ep_store, cx| {
662 // current prediction is second
663 assert_eq!(
664 ep_store
665 .prediction_at(&buffer, None, &project, cx)
666 .unwrap()
667 .id
668 .0,
669 second_id
670 );
671 });
672
673 let first_response = model_response(request1, SIMPLE_DIFF);
674 let first_id = first_response.id.clone();
675 respond_first.send(first_response).unwrap();
676
677 cx.run_until_parked();
678
679 ep_store.update(cx, |ep_store, cx| {
680 // current prediction is still second, since first was cancelled
681 assert_eq!(
682 ep_store
683 .prediction_at(&buffer, None, &project, cx)
684 .unwrap()
685 .id
686 .0,
687 second_id
688 );
689 });
690
691 // first is reported as rejected
692 let (reject_request, _) = requests.reject.next().await.unwrap();
693
694 cx.run_until_parked();
695
696 assert_eq!(
697 &reject_request.rejections,
698 &[EditPredictionRejection {
699 request_id: first_id,
700 reason: EditPredictionRejectReason::Canceled,
701 was_shown: false
702 }]
703 );
704}
705
/// With two refreshes already pending, starting a third one cancels the second.
///
/// The test then answers the requests in the order first, second, third and checks that:
/// - the first response becomes the current prediction,
/// - the second (cancelled) response leaves the current prediction untouched,
/// - the third response replaces the first,
/// - a single reject request reports the second as `Canceled` and the first as `Replaced`,
///   in that order.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        // NOTE(review): `1` presumably identifies the second pending request (zero-based);
        // confirm against how `cancelled_predictions` is populated.
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Answer the first request; it was not cancelled, so it becomes current.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Answer the second (cancelled) request.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Answer the third request.
    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    // The `_` pattern does not bind the response sender, so it is dropped right away.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both rejections arrive in one request: the cancelled second, then the replaced first.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
841
842#[gpui::test]
843async fn test_rejections_flushing(cx: &mut TestAppContext) {
844 let (ep_store, mut requests) = init_test_with_fake_client(cx);
845
846 ep_store.update(cx, |ep_store, _cx| {
847 ep_store.reject_prediction(
848 EditPredictionId("test-1".into()),
849 EditPredictionRejectReason::Discarded,
850 false,
851 );
852 ep_store.reject_prediction(
853 EditPredictionId("test-2".into()),
854 EditPredictionRejectReason::Canceled,
855 true,
856 );
857 });
858
859 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
860 cx.run_until_parked();
861
862 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
863 respond_tx.send(()).unwrap();
864
865 // batched
866 assert_eq!(reject_request.rejections.len(), 2);
867 assert_eq!(
868 reject_request.rejections[0],
869 EditPredictionRejection {
870 request_id: "test-1".to_string(),
871 reason: EditPredictionRejectReason::Discarded,
872 was_shown: false
873 }
874 );
875 assert_eq!(
876 reject_request.rejections[1],
877 EditPredictionRejection {
878 request_id: "test-2".to_string(),
879 reason: EditPredictionRejectReason::Canceled,
880 was_shown: true
881 }
882 );
883
884 // Reaching batch size limit sends without debounce
885 ep_store.update(cx, |ep_store, _cx| {
886 for i in 0..70 {
887 ep_store.reject_prediction(
888 EditPredictionId(format!("batch-{}", i).into()),
889 EditPredictionRejectReason::Discarded,
890 false,
891 );
892 }
893 });
894
895 // First MAX/2 items are sent immediately
896 cx.run_until_parked();
897 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
898 respond_tx.send(()).unwrap();
899
900 assert_eq!(reject_request.rejections.len(), 50);
901 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
902 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
903
904 // Remaining items are debounced with the next batch
905 cx.executor().advance_clock(Duration::from_secs(15));
906 cx.run_until_parked();
907
908 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
909 respond_tx.send(()).unwrap();
910
911 assert_eq!(reject_request.rejections.len(), 20);
912 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
913 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
914
915 // Request failure
916 ep_store.update(cx, |ep_store, _cx| {
917 ep_store.reject_prediction(
918 EditPredictionId("retry-1".into()),
919 EditPredictionRejectReason::Discarded,
920 false,
921 );
922 });
923
924 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
925 cx.run_until_parked();
926
927 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
928 assert_eq!(reject_request.rejections.len(), 1);
929 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
930 // Simulate failure
931 drop(_respond_tx);
932
933 // Add another rejection
934 ep_store.update(cx, |ep_store, _cx| {
935 ep_store.reject_prediction(
936 EditPredictionId("retry-2".into()),
937 EditPredictionRejectReason::Discarded,
938 false,
939 );
940 });
941
942 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
943 cx.run_until_parked();
944
945 // Retry should include both the failed item and the new one
946 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
947 respond_tx.send(()).unwrap();
948
949 assert_eq!(reject_request.rejections.len(), 2);
950 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
951 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
952}
953
954// Skipped until we start including diagnostics in prompt
955// #[gpui::test]
956// async fn test_request_diagnostics(cx: &mut TestAppContext) {
957// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
958// let fs = FakeFs::new(cx.executor());
959// fs.insert_tree(
960// "/root",
961// json!({
962// "foo.md": "Hello!\nBye"
963// }),
964// )
965// .await;
966// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
967
968// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
969// let diagnostic = lsp::Diagnostic {
970// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
971// severity: Some(lsp::DiagnosticSeverity::ERROR),
972// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
973// ..Default::default()
974// };
975
976// project.update(cx, |project, cx| {
977// project.lsp_store().update(cx, |lsp_store, cx| {
978// // Create some diagnostics
979// lsp_store
980// .update_diagnostics(
981// LanguageServerId(0),
982// lsp::PublishDiagnosticsParams {
983// uri: path_to_buffer_uri.clone(),
984// diagnostics: vec![diagnostic],
985// version: None,
986// },
987// None,
988// language::DiagnosticSourceKind::Pushed,
989// &[],
990// cx,
991// )
992// .unwrap();
993// });
994// });
995
996// let buffer = project
997// .update(cx, |project, cx| {
998// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
999// project.open_buffer(path, cx)
1000// })
1001// .await
1002// .unwrap();
1003
1004// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1005// let position = snapshot.anchor_before(language::Point::new(0, 0));
1006
1007// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1008// ep_store.request_prediction(&project, &buffer, position, cx)
1009// });
1010
1011// let (request, _respond_tx) = req_rx.next().await.unwrap();
1012
1013// assert_eq!(request.diagnostic_groups.len(), 1);
1014// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1015// .unwrap();
1016// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1017// assert_eq!(
1018// value,
1019// json!({
1020// "entries": [{
1021// "range": {
1022// "start": 8,
1023// "end": 10
1024// },
1025// "diagnostic": {
1026// "source": null,
1027// "code": null,
1028// "code_description": null,
1029// "severity": 1,
1030// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1031// "markdown": null,
1032// "group_id": 0,
1033// "is_primary": true,
1034// "is_disk_based": false,
1035// "is_unnecessary": false,
1036// "source_kind": "Pushed",
1037// "data": null,
1038// "underline": true
1039// }
1040// }],
1041// "primary_ix": 0
1042// })
1043// );
1044// }
1045
1046// Generate a model response that would apply the given diff to the active file.
1047fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1048 let prompt = match &request.messages[0] {
1049 open_ai::RequestMessage::User {
1050 content: open_ai::MessageContent::Plain(content),
1051 } => content,
1052 _ => panic!("unexpected request {request:?}"),
1053 };
1054
1055 let open = "<editable_region>\n";
1056 let close = "</editable_region>";
1057 let cursor = "<|user_cursor|>";
1058
1059 let start_ix = open.len() + prompt.find(open).unwrap();
1060 let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1061 let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1062 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1063
1064 open_ai::Response {
1065 id: Uuid::new_v4().to_string(),
1066 object: "response".into(),
1067 created: 0,
1068 model: "model".into(),
1069 choices: vec![open_ai::Choice {
1070 index: 0,
1071 message: open_ai::RequestMessage::Assistant {
1072 content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1073 tool_calls: vec![],
1074 },
1075 finish_reason: None,
1076 }],
1077 usage: Usage {
1078 prompt_tokens: 0,
1079 completion_tokens: 0,
1080 total_tokens: 0,
1081 },
1082 }
1083}
1084
1085fn prompt_from_request(request: &open_ai::Request) -> &str {
1086 assert_eq!(request.messages.len(), 1);
1087 let open_ai::RequestMessage::User {
1088 content: open_ai::MessageContent::Plain(content),
1089 ..
1090 } = &request.messages[0]
1091 else {
1092 panic!(
1093 "Request does not have single user message of type Plain. {:#?}",
1094 request
1095 );
1096 };
1097 content
1098}
1099
/// Receiving ends of the channels through which the fake HTTP client hands intercepted
/// requests to a test. Each item pairs the request body with a one-shot sender the test
/// uses to supply the response.
struct RequestChannels {
    /// Requests made to `/predict_edits/raw`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    /// Requests made to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1104
1105fn init_test_with_fake_client(
1106 cx: &mut TestAppContext,
1107) -> (Entity<EditPredictionStore>, RequestChannels) {
1108 cx.update(move |cx| {
1109 let settings_store = SettingsStore::test(cx);
1110 cx.set_global(settings_store);
1111 zlog::init_test();
1112
1113 let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1114 let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1115
1116 let http_client = FakeHttpClient::create({
1117 move |req| {
1118 let uri = req.uri().path().to_string();
1119 let mut body = req.into_body();
1120 let predict_req_tx = predict_req_tx.clone();
1121 let reject_req_tx = reject_req_tx.clone();
1122 async move {
1123 let resp = match uri.as_str() {
1124 "/client/llm_tokens" => serde_json::to_string(&json!({
1125 "token": "test"
1126 }))
1127 .unwrap(),
1128 "/predict_edits/raw" => {
1129 let mut buf = Vec::new();
1130 body.read_to_end(&mut buf).await.ok();
1131 let req = serde_json::from_slice(&buf).unwrap();
1132
1133 let (res_tx, res_rx) = oneshot::channel();
1134 predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1135 serde_json::to_string(&res_rx.await?).unwrap()
1136 }
1137 "/predict_edits/reject" => {
1138 let mut buf = Vec::new();
1139 body.read_to_end(&mut buf).await.ok();
1140 let req = serde_json::from_slice(&buf).unwrap();
1141
1142 let (res_tx, res_rx) = oneshot::channel();
1143 reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1144 serde_json::to_string(&res_rx.await?).unwrap()
1145 }
1146 _ => {
1147 panic!("Unexpected path: {}", uri)
1148 }
1149 };
1150
1151 Ok(Response::builder().body(resp.into()).unwrap())
1152 }
1153 }
1154 });
1155
1156 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1157 client.cloud_client().set_credentials(1, "test".into());
1158
1159 language_model::init(client.clone(), cx);
1160
1161 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1162 let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1163
1164 (
1165 ep_store,
1166 RequestChannels {
1167 predict: predict_req_rx,
1168 reject: reject_req_rx,
1169 },
1170 )
1171 })
1172}
1173
/// Contents of `../license_examples/0bsd.txt`, embedded at compile time.
/// Tests in this file use it as the body of a worktree's `LICENSE` file.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1175
1176#[gpui::test]
1177async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1178 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1179 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1180 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1181 });
1182
1183 let edit_preview = cx
1184 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1185 .await;
1186
1187 let prediction = EditPrediction {
1188 edits,
1189 edit_preview,
1190 buffer: buffer.clone(),
1191 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1192 id: EditPredictionId("the-id".into()),
1193 inputs: ZetaPromptInput {
1194 events: Default::default(),
1195 related_files: Default::default(),
1196 cursor_path: Path::new("").into(),
1197 cursor_excerpt: "".into(),
1198 editable_range_in_excerpt: 0..0,
1199 cursor_offset_in_excerpt: 0,
1200 },
1201 buffer_snapshotted_at: Instant::now(),
1202 response_received_at: Instant::now(),
1203 };
1204
1205 cx.update(|cx| {
1206 assert_eq!(
1207 from_completion_edits(
1208 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1209 &buffer,
1210 cx
1211 ),
1212 vec![(2..5, "REM".into()), (9..11, "".into())]
1213 );
1214
1215 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1216 assert_eq!(
1217 from_completion_edits(
1218 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1219 &buffer,
1220 cx
1221 ),
1222 vec![(2..2, "REM".into()), (6..8, "".into())]
1223 );
1224
1225 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1226 assert_eq!(
1227 from_completion_edits(
1228 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1229 &buffer,
1230 cx
1231 ),
1232 vec![(2..5, "REM".into()), (9..11, "".into())]
1233 );
1234
1235 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1236 assert_eq!(
1237 from_completion_edits(
1238 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1239 &buffer,
1240 cx
1241 ),
1242 vec![(3..3, "EM".into()), (7..9, "".into())]
1243 );
1244
1245 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1246 assert_eq!(
1247 from_completion_edits(
1248 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1249 &buffer,
1250 cx
1251 ),
1252 vec![(4..4, "M".into()), (8..10, "".into())]
1253 );
1254
1255 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1256 assert_eq!(
1257 from_completion_edits(
1258 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1259 &buffer,
1260 cx
1261 ),
1262 vec![(9..11, "".into())]
1263 );
1264
1265 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1266 assert_eq!(
1267 from_completion_edits(
1268 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1269 &buffer,
1270 cx
1271 ),
1272 vec![(4..4, "M".into()), (8..10, "".into())]
1273 );
1274
1275 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1276 assert_eq!(
1277 from_completion_edits(
1278 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1279 &buffer,
1280 cx
1281 ),
1282 vec![(4..4, "M".into())]
1283 );
1284
1285 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1286 assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1287 })
1288}
1289
1290#[gpui::test]
1291async fn test_clean_up_diff(cx: &mut TestAppContext) {
1292 init_test(cx);
1293
1294 assert_eq!(
1295 apply_edit_prediction(
1296 indoc! {"
1297 fn main() {
1298 let word_1 = \"lorem\";
1299 let range = word.len()..word.len();
1300 }
1301 "},
1302 indoc! {"
1303 <|editable_region_start|>
1304 fn main() {
1305 let word_1 = \"lorem\";
1306 let range = word_1.len()..word_1.len();
1307 }
1308
1309 <|editable_region_end|>
1310 "},
1311 cx,
1312 )
1313 .await,
1314 indoc! {"
1315 fn main() {
1316 let word_1 = \"lorem\";
1317 let range = word_1.len()..word_1.len();
1318 }
1319 "},
1320 );
1321
1322 assert_eq!(
1323 apply_edit_prediction(
1324 indoc! {"
1325 fn main() {
1326 let story = \"the quick\"
1327 }
1328 "},
1329 indoc! {"
1330 <|editable_region_start|>
1331 fn main() {
1332 let story = \"the quick brown fox jumps over the lazy dog\";
1333 }
1334
1335 <|editable_region_end|>
1336 "},
1337 cx,
1338 )
1339 .await,
1340 indoc! {"
1341 fn main() {
1342 let story = \"the quick brown fox jumps over the lazy dog\";
1343 }
1344 "},
1345 );
1346}
1347
1348#[gpui::test]
1349async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1350 init_test(cx);
1351
1352 let buffer_content = "lorem\n";
1353 let completion_response = indoc! {"
1354 ```animals.js
1355 <|start_of_file|>
1356 <|editable_region_start|>
1357 lorem
1358 ipsum
1359 <|editable_region_end|>
1360 ```"};
1361
1362 assert_eq!(
1363 apply_edit_prediction(buffer_content, completion_response, cx).await,
1364 "lorem\nipsum"
1365 );
1366}
1367
1368#[gpui::test]
1369async fn test_can_collect_data(cx: &mut TestAppContext) {
1370 init_test(cx);
1371
1372 let fs = project::FakeFs::new(cx.executor());
1373 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1374 .await;
1375
1376 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1377 let buffer = project
1378 .update(cx, |project, cx| {
1379 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1380 })
1381 .await
1382 .unwrap();
1383
1384 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1385 ep_store.update(cx, |ep_store, _cx| {
1386 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1387 });
1388
1389 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1390 assert_eq!(
1391 captured_request.lock().clone().unwrap().can_collect_data,
1392 true
1393 );
1394
1395 ep_store.update(cx, |ep_store, _cx| {
1396 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1397 });
1398
1399 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1400 assert_eq!(
1401 captured_request.lock().clone().unwrap().can_collect_data,
1402 false
1403 );
1404}
1405
1406#[gpui::test]
1407async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1408 init_test(cx);
1409
1410 let fs = project::FakeFs::new(cx.executor());
1411 let project = Project::test(fs.clone(), [], cx).await;
1412
1413 let buffer = cx.new(|_cx| {
1414 Buffer::remote(
1415 language::BufferId::new(1).unwrap(),
1416 ReplicaId::new(1),
1417 language::Capability::ReadWrite,
1418 "fn main() {\n println!(\"Hello\");\n}",
1419 )
1420 });
1421
1422 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1423 ep_store.update(cx, |ep_store, _cx| {
1424 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1425 });
1426
1427 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1428 assert_eq!(
1429 captured_request.lock().clone().unwrap().can_collect_data,
1430 false
1431 );
1432}
1433
1434#[gpui::test]
1435async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1436 init_test(cx);
1437
1438 let fs = project::FakeFs::new(cx.executor());
1439 fs.insert_tree(
1440 path!("/project"),
1441 json!({
1442 "LICENSE": BSD_0_TXT,
1443 ".env": "SECRET_KEY=secret"
1444 }),
1445 )
1446 .await;
1447
1448 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1449 let buffer = project
1450 .update(cx, |project, cx| {
1451 project.open_local_buffer("/project/.env", cx)
1452 })
1453 .await
1454 .unwrap();
1455
1456 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1457 ep_store.update(cx, |ep_store, _cx| {
1458 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1459 });
1460
1461 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1462 assert_eq!(
1463 captured_request.lock().clone().unwrap().can_collect_data,
1464 false
1465 );
1466}
1467
1468#[gpui::test]
1469async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1470 init_test(cx);
1471
1472 let fs = project::FakeFs::new(cx.executor());
1473 let project = Project::test(fs.clone(), [], cx).await;
1474 let buffer = cx.new(|cx| Buffer::local("", cx));
1475
1476 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1477 ep_store.update(cx, |ep_store, _cx| {
1478 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1479 });
1480
1481 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1482 assert_eq!(
1483 captured_request.lock().clone().unwrap().can_collect_data,
1484 false
1485 );
1486}
1487
1488#[gpui::test]
1489async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1490 init_test(cx);
1491
1492 let fs = project::FakeFs::new(cx.executor());
1493 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1494 .await;
1495
1496 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1497 let buffer = project
1498 .update(cx, |project, cx| {
1499 project.open_local_buffer("/project/main.rs", cx)
1500 })
1501 .await
1502 .unwrap();
1503
1504 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1505 ep_store.update(cx, |ep_store, _cx| {
1506 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1507 });
1508
1509 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1510 assert_eq!(
1511 captured_request.lock().clone().unwrap().can_collect_data,
1512 false
1513 );
1514}
1515
1516#[gpui::test]
1517async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1518 init_test(cx);
1519
1520 let fs = project::FakeFs::new(cx.executor());
1521 fs.insert_tree(
1522 path!("/open_source_worktree"),
1523 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1524 )
1525 .await;
1526 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1527 .await;
1528
1529 let project = Project::test(
1530 fs.clone(),
1531 [
1532 path!("/open_source_worktree").as_ref(),
1533 path!("/closed_source_worktree").as_ref(),
1534 ],
1535 cx,
1536 )
1537 .await;
1538 let buffer = project
1539 .update(cx, |project, cx| {
1540 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1541 })
1542 .await
1543 .unwrap();
1544
1545 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1546 ep_store.update(cx, |ep_store, _cx| {
1547 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1548 });
1549
1550 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1551 assert_eq!(
1552 captured_request.lock().clone().unwrap().can_collect_data,
1553 true
1554 );
1555
1556 let closed_source_file = project
1557 .update(cx, |project, cx| {
1558 let worktree2 = project
1559 .worktree_for_root_name("closed_source_worktree", cx)
1560 .unwrap();
1561 worktree2.update(cx, |worktree2, cx| {
1562 worktree2.load_file(rel_path("main.rs"), cx)
1563 })
1564 })
1565 .await
1566 .unwrap()
1567 .file;
1568
1569 buffer.update(cx, |buffer, cx| {
1570 buffer.file_updated(closed_source_file, cx);
1571 });
1572
1573 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1574 assert_eq!(
1575 captured_request.lock().clone().unwrap().can_collect_data,
1576 false
1577 );
1578}
1579
1580#[gpui::test]
1581async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1582 init_test(cx);
1583
1584 let fs = project::FakeFs::new(cx.executor());
1585 fs.insert_tree(
1586 path!("/worktree1"),
1587 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1588 )
1589 .await;
1590 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1591 .await;
1592
1593 let project = Project::test(
1594 fs.clone(),
1595 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1596 cx,
1597 )
1598 .await;
1599 let buffer = project
1600 .update(cx, |project, cx| {
1601 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1602 })
1603 .await
1604 .unwrap();
1605 let private_buffer = project
1606 .update(cx, |project, cx| {
1607 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1608 })
1609 .await
1610 .unwrap();
1611
1612 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1613 ep_store.update(cx, |ep_store, _cx| {
1614 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1615 });
1616
1617 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1618 assert_eq!(
1619 captured_request.lock().clone().unwrap().can_collect_data,
1620 true
1621 );
1622
1623 // this has a side effect of registering the buffer to watch for edits
1624 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1625 assert_eq!(
1626 captured_request.lock().clone().unwrap().can_collect_data,
1627 false
1628 );
1629
1630 private_buffer.update(cx, |private_buffer, cx| {
1631 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1632 });
1633
1634 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1635 assert_eq!(
1636 captured_request.lock().clone().unwrap().can_collect_data,
1637 false
1638 );
1639
1640 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1641 // included
1642 buffer.update(cx, |buffer, cx| {
1643 buffer.edit(
1644 [(
1645 0..0,
1646 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1647 )],
1648 None,
1649 cx,
1650 );
1651 });
1652
1653 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1654 assert_eq!(
1655 captured_request.lock().clone().unwrap().can_collect_data,
1656 true
1657 );
1658}
1659
1660fn init_test(cx: &mut TestAppContext) {
1661 cx.update(|cx| {
1662 let settings_store = SettingsStore::test(cx);
1663 cx.set_global(settings_store);
1664 });
1665}
1666
1667async fn apply_edit_prediction(
1668 buffer_content: &str,
1669 completion_response: &str,
1670 cx: &mut TestAppContext,
1671) -> String {
1672 let fs = project::FakeFs::new(cx.executor());
1673 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1674 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1675 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1676 *response.lock() = completion_response.to_string();
1677 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1678 buffer.update(cx, |buffer, cx| {
1679 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1680 });
1681 buffer.read_with(cx, |buffer, _| buffer.text())
1682}
1683
1684async fn run_edit_prediction(
1685 buffer: &Entity<Buffer>,
1686 project: &Entity<Project>,
1687 ep_store: &Entity<EditPredictionStore>,
1688 cx: &mut TestAppContext,
1689) -> EditPrediction {
1690 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1691 ep_store.update(cx, |ep_store, cx| {
1692 ep_store.register_buffer(buffer, &project, cx)
1693 });
1694 cx.background_executor.run_until_parked();
1695 let prediction_task = ep_store.update(cx, |ep_store, cx| {
1696 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1697 });
1698 prediction_task.await.unwrap().unwrap().prediction.unwrap()
1699}
1700
1701async fn make_test_ep_store(
1702 project: &Entity<Project>,
1703 cx: &mut TestAppContext,
1704) -> (
1705 Entity<EditPredictionStore>,
1706 Arc<Mutex<Option<PredictEditsBody>>>,
1707 Arc<Mutex<String>>,
1708) {
1709 let default_response = indoc! {"
1710 ```main.rs
1711 <|start_of_file|>
1712 <|editable_region_start|>
1713 hello world
1714 <|editable_region_end|>
1715 ```"
1716 };
1717 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1718 let completion_response: Arc<Mutex<String>> =
1719 Arc::new(Mutex::new(default_response.to_string()));
1720 let http_client = FakeHttpClient::create({
1721 let captured_request = captured_request.clone();
1722 let completion_response = completion_response.clone();
1723 let mut next_request_id = 0;
1724 move |req| {
1725 let captured_request = captured_request.clone();
1726 let completion_response = completion_response.clone();
1727 async move {
1728 match (req.method(), req.uri().path()) {
1729 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1730 .status(200)
1731 .body(
1732 serde_json::to_string(&CreateLlmTokenResponse {
1733 token: LlmToken("the-llm-token".to_string()),
1734 })
1735 .unwrap()
1736 .into(),
1737 )
1738 .unwrap()),
1739 (&Method::POST, "/predict_edits/v2") => {
1740 let mut request_body = String::new();
1741 req.into_body().read_to_string(&mut request_body).await?;
1742 *captured_request.lock() =
1743 Some(serde_json::from_str(&request_body).unwrap());
1744 next_request_id += 1;
1745 Ok(http_client::Response::builder()
1746 .status(200)
1747 .body(
1748 serde_json::to_string(&PredictEditsResponse {
1749 request_id: format!("request-{next_request_id}"),
1750 output_excerpt: completion_response.lock().clone(),
1751 })
1752 .unwrap()
1753 .into(),
1754 )
1755 .unwrap())
1756 }
1757 _ => Ok(http_client::Response::builder()
1758 .status(404)
1759 .body("Not Found".into())
1760 .unwrap()),
1761 }
1762 }
1763 }
1764 });
1765
1766 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1767 cx.update(|cx| {
1768 RefreshLlmTokenListener::register(client.clone(), cx);
1769 });
1770 let _server = FakeServer::for_client(42, &client, cx).await;
1771
1772 let ep_store = cx.new(|cx| {
1773 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
1774 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
1775
1776 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
1777 for worktree in worktrees {
1778 let worktree_id = worktree.read(cx).id();
1779 ep_store
1780 .get_or_init_project(project, cx)
1781 .license_detection_watchers
1782 .entry(worktree_id)
1783 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
1784 }
1785
1786 ep_store
1787 });
1788
1789 (ep_store, captured_request, completion_response)
1790}
1791
1792fn to_completion_edits(
1793 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1794 buffer: &Entity<Buffer>,
1795 cx: &App,
1796) -> Vec<(Range<Anchor>, Arc<str>)> {
1797 let buffer = buffer.read(cx);
1798 iterator
1799 .into_iter()
1800 .map(|(range, text)| {
1801 (
1802 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1803 text,
1804 )
1805 })
1806 .collect()
1807}
1808
1809fn from_completion_edits(
1810 editor_edits: &[(Range<Anchor>, Arc<str>)],
1811 buffer: &Entity<Buffer>,
1812 cx: &App,
1813) -> Vec<(Range<usize>, Arc<str>)> {
1814 let buffer = buffer.read(cx);
1815 editor_edits
1816 .iter()
1817 .map(|(range, text)| {
1818 (
1819 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1820 text.clone(),
1821 )
1822 })
1823 .collect()
1824}
1825
1826#[ctor::ctor]
1827fn init_logger() {
1828 zlog::init_test();
1829}