1use super::*;
2use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9};
10use futures::{
11 AsyncReadExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::{
15 Entity, TestAppContext,
16 http_client::{FakeHttpClient, Response},
17};
18use indoc::indoc;
19use language::{Point, ToOffset as _};
20use lsp::LanguageServerId;
21use open_ai::Usage;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::{path, rel_path::rel_path};
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
/// Checks what `current_prediction_for_buffer` reports as the prediction target changes.
///
/// 1. With `1.txt` as the active path, a prediction for that buffer is reported as
///    `BufferEditPrediction::Local`.
/// 2. After that prediction is rejected, a diagnostic is published for `2.txt`. The test
///    expects this to trigger another prediction request. The resulting edit to `2.txt` is
///    reported as a `Jump` when queried from `1.txt`, and as `Local` once `2.txt` is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
    });

    // Open 1.txt and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" on the second line.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Clear the current prediction before exercising the diagnostic-driven case.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publish the diagnostic for 2.txt; the next `requests.predict.next()` below
    // expects this to cause a new prediction request.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from 1.txt, the prediction targets another file, so it must be a
    // Jump whose snapshot belongs to root/2.txt.
    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from 2.txt itself, the same prediction is Local.
    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer2, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
168
/// Requests a prediction directly via `request_prediction` and checks the resulting edit.
///
/// The fake model response rewrites `How` to `How are you?`. The test expects this to become
/// a single edit that inserts ` are you?` starting at (1, 3).
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor right after "How" on the second line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // Only the inserted suffix is expected as the edit, anchored at the cursor.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
235
/// Checks that user edits to a registered buffer are included in the prediction prompt.
///
/// The buffer starts as `Hello!\n\nBye\n`. Inserting `How` at offset 7 (the start of the
/// empty second line) must appear in the prompt as a unified diff of that edit. The model
/// response is then applied as a single ` are you?` insertion at (1, 3).
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register the buffer before editing it, so the edit below happens on a buffer
    // the store knows about.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt must contain the edit made above, rendered as a unified diff.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
312
/// Checks handling of a response built from an empty diff.
///
/// The response should leave the editable region unchanged. The test expects no current
/// prediction for the buffer and a rejection report with reason `Empty` and
/// `was_shown: false`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Answer with the unchanged excerpt; keep its id to match the rejection report.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(request, "");
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        assert!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
367
/// Checks a prediction that the user's own typing has already made redundant.
///
/// While the request is in flight, the buffer is changed to already contain
/// `How are you?`. The arriving `SIMPLE_DIFF` prediction therefore has nothing left to
/// suggest. The test expects no current prediction and a rejection with reason
/// `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The edit must land after the request was captured but before the response
    // is delivered, so the buffer already matches the predicted text.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(request, SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        assert!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
427
/// Unified diff for `root/foo.md` that changes the line `How` to `How are you?`
/// between the context lines `Hello!` and `Bye`. It is passed to `model_response`
/// by several tests in this file.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
437
/// Checks that a newer prediction for the same buffer replaces the current one.
///
/// Two `SIMPLE_DIFF` responses are delivered one after the other. The second becomes the
/// current prediction, and the first is reported as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
519
/// Checks that a worse follow-up prediction does not displace the current one.
///
/// The first response completes the line to `How are you?`. The second only produces
/// `How are`. The test expects the first to remain current and the second to be reported
/// as rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
613
/// Checks that a later request's response wins over an earlier, still-pending one.
///
/// Two refreshes are started before either is answered. The second request is answered
/// first and becomes current. The first is answered afterwards; it does not replace the
/// second and is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // `request` here is the second request.
    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
704
/// Checks cancellation when a third request starts while two are still pending.
///
/// Starting the third request records `1` in the project's `cancelled_predictions`. The
/// in-code comment says this is the 2nd request (NOTE(review): confirm that the stored
/// value is a request index). The first response becomes current, and the cancelled
/// second response is ignored. The third response then replaces the first.
///
/// Both rejections arrive in one reject request: the second as `Canceled`, the first as
/// `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second (cancelled) and first (replaced by third) are reported in one batch
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
840
841#[gpui::test]
842async fn test_rejections_flushing(cx: &mut TestAppContext) {
843 let (ep_store, mut requests) = init_test_with_fake_client(cx);
844
845 ep_store.update(cx, |ep_store, _cx| {
846 ep_store.reject_prediction(
847 EditPredictionId("test-1".into()),
848 EditPredictionRejectReason::Discarded,
849 false,
850 );
851 ep_store.reject_prediction(
852 EditPredictionId("test-2".into()),
853 EditPredictionRejectReason::Canceled,
854 true,
855 );
856 });
857
858 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
859 cx.run_until_parked();
860
861 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
862 respond_tx.send(()).unwrap();
863
864 // batched
865 assert_eq!(reject_request.rejections.len(), 2);
866 assert_eq!(
867 reject_request.rejections[0],
868 EditPredictionRejection {
869 request_id: "test-1".to_string(),
870 reason: EditPredictionRejectReason::Discarded,
871 was_shown: false
872 }
873 );
874 assert_eq!(
875 reject_request.rejections[1],
876 EditPredictionRejection {
877 request_id: "test-2".to_string(),
878 reason: EditPredictionRejectReason::Canceled,
879 was_shown: true
880 }
881 );
882
883 // Reaching batch size limit sends without debounce
884 ep_store.update(cx, |ep_store, _cx| {
885 for i in 0..70 {
886 ep_store.reject_prediction(
887 EditPredictionId(format!("batch-{}", i).into()),
888 EditPredictionRejectReason::Discarded,
889 false,
890 );
891 }
892 });
893
894 // First MAX/2 items are sent immediately
895 cx.run_until_parked();
896 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
897 respond_tx.send(()).unwrap();
898
899 assert_eq!(reject_request.rejections.len(), 50);
900 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
901 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
902
903 // Remaining items are debounced with the next batch
904 cx.executor().advance_clock(Duration::from_secs(15));
905 cx.run_until_parked();
906
907 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
908 respond_tx.send(()).unwrap();
909
910 assert_eq!(reject_request.rejections.len(), 20);
911 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
912 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
913
914 // Request failure
915 ep_store.update(cx, |ep_store, _cx| {
916 ep_store.reject_prediction(
917 EditPredictionId("retry-1".into()),
918 EditPredictionRejectReason::Discarded,
919 false,
920 );
921 });
922
923 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
924 cx.run_until_parked();
925
926 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
927 assert_eq!(reject_request.rejections.len(), 1);
928 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
929 // Simulate failure
930 drop(_respond_tx);
931
932 // Add another rejection
933 ep_store.update(cx, |ep_store, _cx| {
934 ep_store.reject_prediction(
935 EditPredictionId("retry-2".into()),
936 EditPredictionRejectReason::Discarded,
937 false,
938 );
939 });
940
941 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
942 cx.run_until_parked();
943
944 // Retry should include both the failed item and the new one
945 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
946 respond_tx.send(()).unwrap();
947
948 assert_eq!(reject_request.rejections.len(), 2);
949 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
950 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
951}
952
953// Skipped until we start including diagnostics in prompt
954// #[gpui::test]
955// async fn test_request_diagnostics(cx: &mut TestAppContext) {
956// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
957// let fs = FakeFs::new(cx.executor());
958// fs.insert_tree(
959// "/root",
960// json!({
961// "foo.md": "Hello!\nBye"
962// }),
963// )
964// .await;
965// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
966
967// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
968// let diagnostic = lsp::Diagnostic {
969// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
970// severity: Some(lsp::DiagnosticSeverity::ERROR),
971// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
972// ..Default::default()
973// };
974
975// project.update(cx, |project, cx| {
976// project.lsp_store().update(cx, |lsp_store, cx| {
977// // Create some diagnostics
978// lsp_store
979// .update_diagnostics(
980// LanguageServerId(0),
981// lsp::PublishDiagnosticsParams {
982// uri: path_to_buffer_uri.clone(),
983// diagnostics: vec![diagnostic],
984// version: None,
985// },
986// None,
987// language::DiagnosticSourceKind::Pushed,
988// &[],
989// cx,
990// )
991// .unwrap();
992// });
993// });
994
995// let buffer = project
996// .update(cx, |project, cx| {
997// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
998// project.open_buffer(path, cx)
999// })
1000// .await
1001// .unwrap();
1002
1003// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1004// let position = snapshot.anchor_before(language::Point::new(0, 0));
1005
1006// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1007// ep_store.request_prediction(&project, &buffer, position, cx)
1008// });
1009
1010// let (request, _respond_tx) = req_rx.next().await.unwrap();
1011
1012// assert_eq!(request.diagnostic_groups.len(), 1);
1013// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1014// .unwrap();
1015// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1016// assert_eq!(
1017// value,
1018// json!({
1019// "entries": [{
1020// "range": {
1021// "start": 8,
1022// "end": 10
1023// },
1024// "diagnostic": {
1025// "source": null,
1026// "code": null,
1027// "code_description": null,
1028// "severity": 1,
1029// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1030// "markdown": null,
1031// "group_id": 0,
1032// "is_primary": true,
1033// "is_disk_based": false,
1034// "is_unnecessary": false,
1035// "source_kind": "Pushed",
1036// "data": null,
1037// "underline": true
1038// }
1039// }],
1040// "primary_ix": 0
1041// })
1042// );
1043// }
1044
1045// Generate a model response that would apply the given diff to the active file.
1046fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1047 let prompt = match &request.messages[0] {
1048 open_ai::RequestMessage::User {
1049 content: open_ai::MessageContent::Plain(content),
1050 } => content,
1051 _ => panic!("unexpected request {request:?}"),
1052 };
1053
1054 let open = "<editable_region>\n";
1055 let close = "</editable_region>";
1056 let cursor = "<|user_cursor|>";
1057
1058 let start_ix = open.len() + prompt.find(open).unwrap();
1059 let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1060 let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1061 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1062
1063 open_ai::Response {
1064 id: Uuid::new_v4().to_string(),
1065 object: "response".into(),
1066 created: 0,
1067 model: "model".into(),
1068 choices: vec![open_ai::Choice {
1069 index: 0,
1070 message: open_ai::RequestMessage::Assistant {
1071 content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1072 tool_calls: vec![],
1073 },
1074 finish_reason: None,
1075 }],
1076 usage: Usage {
1077 prompt_tokens: 0,
1078 completion_tokens: 0,
1079 total_tokens: 0,
1080 },
1081 }
1082}
1083
1084fn prompt_from_request(request: &open_ai::Request) -> &str {
1085 assert_eq!(request.messages.len(), 1);
1086 let open_ai::RequestMessage::User {
1087 content: open_ai::MessageContent::Plain(content),
1088 ..
1089 } = &request.messages[0]
1090 else {
1091 panic!(
1092 "Request does not have single user message of type Plain. {:#?}",
1093 request
1094 );
1095 };
1096 content
1097}
1098
1099struct RequestChannels {
1100 predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
1101 reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1102}
1103
/// Creates the global `EditPredictionStore` backed by a fake HTTP client. The test drives the
/// prediction and rejection endpoints through the returned [`RequestChannels`].
///
/// The fake client answers `/client/llm_tokens` immediately with a fixed token. Requests to
/// `/predict_edits/raw` and `/predict_edits/reject` are deserialized and forwarded over the
/// matching channel. Their HTTP response is produced only after the test sends a value on the
/// paired oneshot sender. Any other path panics.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                // The handler runs once per request and each `async move` future takes
                // ownership, so every request gets its own clone of the senders.
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            // Read errors are discarded (`.ok()`); a malformed body would
                            // then fail in the `from_slice(..).unwrap()` below.
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Hand the request to the test and block this response on the
                            // test's reply. `?` fails the HTTP call if the test drops the
                            // sender without replying.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Same hand-off as the predict endpoint, but over the reject
                            // channel.
                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    // No status is set, so the builder's default (200 OK) is used.
                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // Give the cloud client fixed test credentials.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1172
/// Text of a 0BSD license. The data-collection tests insert it as `LICENSE` into fake
/// worktrees whose files they expect to be eligible for collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1174
1175#[gpui::test]
1176async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1177 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1178 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1179 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1180 });
1181
1182 let edit_preview = cx
1183 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1184 .await;
1185
1186 let prediction = EditPrediction {
1187 edits,
1188 edit_preview,
1189 buffer: buffer.clone(),
1190 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1191 id: EditPredictionId("the-id".into()),
1192 inputs: ZetaPromptInput {
1193 events: Default::default(),
1194 related_files: Default::default(),
1195 cursor_path: Path::new("").into(),
1196 cursor_excerpt: "".into(),
1197 editable_range_in_excerpt: 0..0,
1198 cursor_offset_in_excerpt: 0,
1199 },
1200 buffer_snapshotted_at: Instant::now(),
1201 response_received_at: Instant::now(),
1202 };
1203
1204 cx.update(|cx| {
1205 assert_eq!(
1206 from_completion_edits(
1207 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1208 &buffer,
1209 cx
1210 ),
1211 vec![(2..5, "REM".into()), (9..11, "".into())]
1212 );
1213
1214 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1215 assert_eq!(
1216 from_completion_edits(
1217 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1218 &buffer,
1219 cx
1220 ),
1221 vec![(2..2, "REM".into()), (6..8, "".into())]
1222 );
1223
1224 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1225 assert_eq!(
1226 from_completion_edits(
1227 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1228 &buffer,
1229 cx
1230 ),
1231 vec![(2..5, "REM".into()), (9..11, "".into())]
1232 );
1233
1234 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1235 assert_eq!(
1236 from_completion_edits(
1237 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1238 &buffer,
1239 cx
1240 ),
1241 vec![(3..3, "EM".into()), (7..9, "".into())]
1242 );
1243
1244 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1245 assert_eq!(
1246 from_completion_edits(
1247 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1248 &buffer,
1249 cx
1250 ),
1251 vec![(4..4, "M".into()), (8..10, "".into())]
1252 );
1253
1254 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1255 assert_eq!(
1256 from_completion_edits(
1257 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1258 &buffer,
1259 cx
1260 ),
1261 vec![(9..11, "".into())]
1262 );
1263
1264 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1265 assert_eq!(
1266 from_completion_edits(
1267 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1268 &buffer,
1269 cx
1270 ),
1271 vec![(4..4, "M".into()), (8..10, "".into())]
1272 );
1273
1274 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1275 assert_eq!(
1276 from_completion_edits(
1277 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1278 &buffer,
1279 cx
1280 ),
1281 vec![(4..4, "M".into())]
1282 );
1283
1284 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1285 assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1286 })
1287}
1288
1289#[gpui::test]
1290async fn test_clean_up_diff(cx: &mut TestAppContext) {
1291 init_test(cx);
1292
1293 assert_eq!(
1294 apply_edit_prediction(
1295 indoc! {"
1296 fn main() {
1297 let word_1 = \"lorem\";
1298 let range = word.len()..word.len();
1299 }
1300 "},
1301 indoc! {"
1302 <|editable_region_start|>
1303 fn main() {
1304 let word_1 = \"lorem\";
1305 let range = word_1.len()..word_1.len();
1306 }
1307
1308 <|editable_region_end|>
1309 "},
1310 cx,
1311 )
1312 .await,
1313 indoc! {"
1314 fn main() {
1315 let word_1 = \"lorem\";
1316 let range = word_1.len()..word_1.len();
1317 }
1318 "},
1319 );
1320
1321 assert_eq!(
1322 apply_edit_prediction(
1323 indoc! {"
1324 fn main() {
1325 let story = \"the quick\"
1326 }
1327 "},
1328 indoc! {"
1329 <|editable_region_start|>
1330 fn main() {
1331 let story = \"the quick brown fox jumps over the lazy dog\";
1332 }
1333
1334 <|editable_region_end|>
1335 "},
1336 cx,
1337 )
1338 .await,
1339 indoc! {"
1340 fn main() {
1341 let story = \"the quick brown fox jumps over the lazy dog\";
1342 }
1343 "},
1344 );
1345}
1346
1347#[gpui::test]
1348async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1349 init_test(cx);
1350
1351 let buffer_content = "lorem\n";
1352 let completion_response = indoc! {"
1353 ```animals.js
1354 <|start_of_file|>
1355 <|editable_region_start|>
1356 lorem
1357 ipsum
1358 <|editable_region_end|>
1359 ```"};
1360
1361 assert_eq!(
1362 apply_edit_prediction(buffer_content, completion_response, cx).await,
1363 "lorem\nipsum"
1364 );
1365}
1366
1367#[gpui::test]
1368async fn test_can_collect_data(cx: &mut TestAppContext) {
1369 init_test(cx);
1370
1371 let fs = project::FakeFs::new(cx.executor());
1372 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1373 .await;
1374
1375 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1376 let buffer = project
1377 .update(cx, |project, cx| {
1378 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1379 })
1380 .await
1381 .unwrap();
1382
1383 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1384 ep_store.update(cx, |ep_store, _cx| {
1385 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1386 });
1387
1388 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1389 assert_eq!(
1390 captured_request.lock().clone().unwrap().can_collect_data,
1391 true
1392 );
1393
1394 ep_store.update(cx, |ep_store, _cx| {
1395 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1396 });
1397
1398 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1399 assert_eq!(
1400 captured_request.lock().clone().unwrap().can_collect_data,
1401 false
1402 );
1403}
1404
1405#[gpui::test]
1406async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1407 init_test(cx);
1408
1409 let fs = project::FakeFs::new(cx.executor());
1410 let project = Project::test(fs.clone(), [], cx).await;
1411
1412 let buffer = cx.new(|_cx| {
1413 Buffer::remote(
1414 language::BufferId::new(1).unwrap(),
1415 ReplicaId::new(1),
1416 language::Capability::ReadWrite,
1417 "fn main() {\n println!(\"Hello\");\n}",
1418 )
1419 });
1420
1421 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1422 ep_store.update(cx, |ep_store, _cx| {
1423 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1424 });
1425
1426 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1427 assert_eq!(
1428 captured_request.lock().clone().unwrap().can_collect_data,
1429 false
1430 );
1431}
1432
1433#[gpui::test]
1434async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1435 init_test(cx);
1436
1437 let fs = project::FakeFs::new(cx.executor());
1438 fs.insert_tree(
1439 path!("/project"),
1440 json!({
1441 "LICENSE": BSD_0_TXT,
1442 ".env": "SECRET_KEY=secret"
1443 }),
1444 )
1445 .await;
1446
1447 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1448 let buffer = project
1449 .update(cx, |project, cx| {
1450 project.open_local_buffer("/project/.env", cx)
1451 })
1452 .await
1453 .unwrap();
1454
1455 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1456 ep_store.update(cx, |ep_store, _cx| {
1457 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1458 });
1459
1460 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1461 assert_eq!(
1462 captured_request.lock().clone().unwrap().can_collect_data,
1463 false
1464 );
1465}
1466
1467#[gpui::test]
1468async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1469 init_test(cx);
1470
1471 let fs = project::FakeFs::new(cx.executor());
1472 let project = Project::test(fs.clone(), [], cx).await;
1473 let buffer = cx.new(|cx| Buffer::local("", cx));
1474
1475 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1476 ep_store.update(cx, |ep_store, _cx| {
1477 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1478 });
1479
1480 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1481 assert_eq!(
1482 captured_request.lock().clone().unwrap().can_collect_data,
1483 false
1484 );
1485}
1486
1487#[gpui::test]
1488async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1489 init_test(cx);
1490
1491 let fs = project::FakeFs::new(cx.executor());
1492 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1493 .await;
1494
1495 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1496 let buffer = project
1497 .update(cx, |project, cx| {
1498 project.open_local_buffer("/project/main.rs", cx)
1499 })
1500 .await
1501 .unwrap();
1502
1503 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1504 ep_store.update(cx, |ep_store, _cx| {
1505 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1506 });
1507
1508 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1509 assert_eq!(
1510 captured_request.lock().clone().unwrap().can_collect_data,
1511 false
1512 );
1513}
1514
1515#[gpui::test]
1516async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1517 init_test(cx);
1518
1519 let fs = project::FakeFs::new(cx.executor());
1520 fs.insert_tree(
1521 path!("/open_source_worktree"),
1522 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1523 )
1524 .await;
1525 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1526 .await;
1527
1528 let project = Project::test(
1529 fs.clone(),
1530 [
1531 path!("/open_source_worktree").as_ref(),
1532 path!("/closed_source_worktree").as_ref(),
1533 ],
1534 cx,
1535 )
1536 .await;
1537 let buffer = project
1538 .update(cx, |project, cx| {
1539 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1540 })
1541 .await
1542 .unwrap();
1543
1544 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1545 ep_store.update(cx, |ep_store, _cx| {
1546 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1547 });
1548
1549 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1550 assert_eq!(
1551 captured_request.lock().clone().unwrap().can_collect_data,
1552 true
1553 );
1554
1555 let closed_source_file = project
1556 .update(cx, |project, cx| {
1557 let worktree2 = project
1558 .worktree_for_root_name("closed_source_worktree", cx)
1559 .unwrap();
1560 worktree2.update(cx, |worktree2, cx| {
1561 worktree2.load_file(rel_path("main.rs"), cx)
1562 })
1563 })
1564 .await
1565 .unwrap()
1566 .file;
1567
1568 buffer.update(cx, |buffer, cx| {
1569 buffer.file_updated(closed_source_file, cx);
1570 });
1571
1572 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1573 assert_eq!(
1574 captured_request.lock().clone().unwrap().can_collect_data,
1575 false
1576 );
1577}
1578
1579#[gpui::test]
1580async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1581 init_test(cx);
1582
1583 let fs = project::FakeFs::new(cx.executor());
1584 fs.insert_tree(
1585 path!("/worktree1"),
1586 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1587 )
1588 .await;
1589 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1590 .await;
1591
1592 let project = Project::test(
1593 fs.clone(),
1594 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1595 cx,
1596 )
1597 .await;
1598 let buffer = project
1599 .update(cx, |project, cx| {
1600 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1601 })
1602 .await
1603 .unwrap();
1604 let private_buffer = project
1605 .update(cx, |project, cx| {
1606 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1607 })
1608 .await
1609 .unwrap();
1610
1611 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1612 ep_store.update(cx, |ep_store, _cx| {
1613 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1614 });
1615
1616 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1617 assert_eq!(
1618 captured_request.lock().clone().unwrap().can_collect_data,
1619 true
1620 );
1621
1622 // this has a side effect of registering the buffer to watch for edits
1623 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1624 assert_eq!(
1625 captured_request.lock().clone().unwrap().can_collect_data,
1626 false
1627 );
1628
1629 private_buffer.update(cx, |private_buffer, cx| {
1630 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1631 });
1632
1633 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1634 assert_eq!(
1635 captured_request.lock().clone().unwrap().can_collect_data,
1636 false
1637 );
1638
1639 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1640 // included
1641 buffer.update(cx, |buffer, cx| {
1642 buffer.edit(
1643 [(
1644 0..0,
1645 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1646 )],
1647 None,
1648 cx,
1649 );
1650 });
1651
1652 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1653 assert_eq!(
1654 captured_request.lock().clone().unwrap().can_collect_data,
1655 true
1656 );
1657}
1658
1659fn init_test(cx: &mut TestAppContext) {
1660 cx.update(|cx| {
1661 let settings_store = SettingsStore::test(cx);
1662 cx.set_global(settings_store);
1663 });
1664}
1665
1666async fn apply_edit_prediction(
1667 buffer_content: &str,
1668 completion_response: &str,
1669 cx: &mut TestAppContext,
1670) -> String {
1671 let fs = project::FakeFs::new(cx.executor());
1672 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1673 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1674 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1675 *response.lock() = completion_response.to_string();
1676 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1677 buffer.update(cx, |buffer, cx| {
1678 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1679 });
1680 buffer.read_with(cx, |buffer, _| buffer.text())
1681}
1682
1683async fn run_edit_prediction(
1684 buffer: &Entity<Buffer>,
1685 project: &Entity<Project>,
1686 ep_store: &Entity<EditPredictionStore>,
1687 cx: &mut TestAppContext,
1688) -> EditPrediction {
1689 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1690 ep_store.update(cx, |ep_store, cx| {
1691 ep_store.register_buffer(buffer, &project, cx)
1692 });
1693 cx.background_executor.run_until_parked();
1694 let prediction_task = ep_store.update(cx, |ep_store, cx| {
1695 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1696 });
1697 prediction_task.await.unwrap().unwrap().prediction.unwrap()
1698}
1699
1700async fn make_test_ep_store(
1701 project: &Entity<Project>,
1702 cx: &mut TestAppContext,
1703) -> (
1704 Entity<EditPredictionStore>,
1705 Arc<Mutex<Option<PredictEditsBody>>>,
1706 Arc<Mutex<String>>,
1707) {
1708 let default_response = indoc! {"
1709 ```main.rs
1710 <|start_of_file|>
1711 <|editable_region_start|>
1712 hello world
1713 <|editable_region_end|>
1714 ```"
1715 };
1716 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1717 let completion_response: Arc<Mutex<String>> =
1718 Arc::new(Mutex::new(default_response.to_string()));
1719 let http_client = FakeHttpClient::create({
1720 let captured_request = captured_request.clone();
1721 let completion_response = completion_response.clone();
1722 let mut next_request_id = 0;
1723 move |req| {
1724 let captured_request = captured_request.clone();
1725 let completion_response = completion_response.clone();
1726 async move {
1727 match (req.method(), req.uri().path()) {
1728 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1729 .status(200)
1730 .body(
1731 serde_json::to_string(&CreateLlmTokenResponse {
1732 token: LlmToken("the-llm-token".to_string()),
1733 })
1734 .unwrap()
1735 .into(),
1736 )
1737 .unwrap()),
1738 (&Method::POST, "/predict_edits/v2") => {
1739 let mut request_body = String::new();
1740 req.into_body().read_to_string(&mut request_body).await?;
1741 *captured_request.lock() =
1742 Some(serde_json::from_str(&request_body).unwrap());
1743 next_request_id += 1;
1744 Ok(http_client::Response::builder()
1745 .status(200)
1746 .body(
1747 serde_json::to_string(&PredictEditsResponse {
1748 request_id: format!("request-{next_request_id}"),
1749 output_excerpt: completion_response.lock().clone(),
1750 })
1751 .unwrap()
1752 .into(),
1753 )
1754 .unwrap())
1755 }
1756 _ => Ok(http_client::Response::builder()
1757 .status(404)
1758 .body("Not Found".into())
1759 .unwrap()),
1760 }
1761 }
1762 }
1763 });
1764
1765 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1766 cx.update(|cx| {
1767 RefreshLlmTokenListener::register(client.clone(), cx);
1768 });
1769 let _server = FakeServer::for_client(42, &client, cx).await;
1770
1771 let ep_store = cx.new(|cx| {
1772 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
1773 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
1774
1775 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
1776 for worktree in worktrees {
1777 let worktree_id = worktree.read(cx).id();
1778 ep_store
1779 .get_or_init_project(project, cx)
1780 .license_detection_watchers
1781 .entry(worktree_id)
1782 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
1783 }
1784
1785 ep_store
1786 });
1787
1788 (ep_store, captured_request, completion_response)
1789}
1790
1791fn to_completion_edits(
1792 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1793 buffer: &Entity<Buffer>,
1794 cx: &App,
1795) -> Vec<(Range<Anchor>, Arc<str>)> {
1796 let buffer = buffer.read(cx);
1797 iterator
1798 .into_iter()
1799 .map(|(range, text)| {
1800 (
1801 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1802 text,
1803 )
1804 })
1805 .collect()
1806}
1807
1808fn from_completion_edits(
1809 editor_edits: &[(Range<Anchor>, Arc<str>)],
1810 buffer: &Entity<Buffer>,
1811 cx: &App,
1812) -> Vec<(Range<usize>, Arc<str>)> {
1813 let buffer = buffer.read(cx);
1814 editor_edits
1815 .iter()
1816 .map(|(range, text)| {
1817 (
1818 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1819 text.clone(),
1820 )
1821 })
1822 .collect()
1823}
1824
/// Initializes test logging once, when the test binary is loaded and before any test runs.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}