1use super::*;
2use crate::zeta1::MAX_EVENT_TOKENS;
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9};
10use edit_prediction_context::Line;
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::{Point, ToOffset as _};
21use lsp::LanguageServerId;
22use open_ai::Usage;
23use parking_lot::Mutex;
24use pretty_assertions::{assert_eq, assert_matches};
25use project::{FakeFs, Project};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, sync::Arc, time::Duration};
29use util::{path, rel_path::rel_path};
30use uuid::Uuid;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    // Verifies how `current_prediction_for_buffer` presents a prediction:
    // a prediction editing the queried buffer is `Local`, while a prediction
    // targeting a different file is `Jump` until that file's own buffer is
    // queried.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
    });

    // Open 1.txt and mark it active; the cursor anchor sits mid-word on
    // line 1 ("How").
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // The fake model answers with a unified diff against the active file.
    respond_tx
        .send(model_response(indoc! {r"
            --- a/root/1.txt
            +++ b/root/1.txt
            @@ ... @@
            Hello!
            -How
            +How are you?
            Bye
        "}))
        .unwrap();

    cx.run_until_parked();

    // Prediction edits the queried buffer itself -> presented as Local.
    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request
    // (no explicit refresh call is made here).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(indoc! {r#"
            --- a/root/2.txt
            +++ b/root/2.txt
            Hola!
            -Como
            +Como estas?
            Adios
        "#}))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction lives elsewhere -> Jump,
    // pointing at 2.txt.
    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // The same prediction is Local when queried from the target buffer.
    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer2, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
161
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    // Happy path for `request_prediction`: one request, one diff response,
    // and the resulting prediction contains the minimal edit (insertion of
    // " are you?" at the cursor position).
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are you?
            Bye
        "}))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The full-line replacement in the diff is reduced to a single
    // insertion after the unchanged prefix "How".
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
225
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    // Verifies that edits made to a registered buffer are included in the
    // prompt sent upstream as a unified diff of the recent edit history.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registration is required so the store starts tracking edit events.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" on the (previously empty) second line.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The recorded edit should show up in the prompt as a diff hunk.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(indoc! {r#"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are you?
            Bye
        "#}))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
299
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    // A diff that makes no effective change ("-How" / "+How") must not be
    // surfaced as a prediction, and must be reported upstream as rejected
    // with reason `Empty`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Replaces "How" with the identical "How" -> no-op.
    const NO_OP_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How
        Bye
    "};

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(NO_OP_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // No current prediction should be stored for the buffer.
    ep_store.read_with(cx, |ep_store, cx| {
        assert!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
364
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    // If the user types the predicted text while the request is in flight,
    // interpolating the prediction against the new buffer contents yields
    // nothing; the prediction is dropped and reported as
    // `InterpolatedEmpty`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the exact edit the model is about to predict, before the
    // response arrives.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        assert!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
424
/// Diff fixture shared by several tests below: rewrites line 2 of
/// `root/foo.md` from "How" to "How are you?" (i.e. inserts " are you?"
/// at the cursor position used throughout these tests).
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
434
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    // Two sequential refreshes: the second prediction replaces the first,
    // and the first is reported upstream with reason `Replaced`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First prediction is current.
    ep_store.read_with(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
516
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    // When a new prediction arrives that is considered worse than the
    // current one, the current prediction is kept and the new one is
    // rejected with reason `CurrentPreferred`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are
        Bye
    "});
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
607
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    // With two requests in flight, a response to the *newer* request
    // cancels the older one: when the older response eventually arrives it
    // is ignored and reported with reason `Canceled`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
698
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    // With two requests already pending, starting a third cancels the
    // second (most recent pending) request. The third's response then
    // replaces the first's, so the final reject report contains the second
    // (`Canceled`) and the first (`Replaced`), in that order.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (_, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled (second) request responding has no effect.
    let cancelled_response = model_response(SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
834
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    // Covers the batching behavior of rejection reporting:
    // - rejections are debounced and sent in one batch,
    // - hitting the batch-size limit flushes immediately (first 50 of 70),
    // - a failed request's items are retried together with newer ones.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
946
947// Skipped until we start including diagnostics in prompt
948// #[gpui::test]
949// async fn test_request_diagnostics(cx: &mut TestAppContext) {
950// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
951// let fs = FakeFs::new(cx.executor());
952// fs.insert_tree(
953// "/root",
954// json!({
955// "foo.md": "Hello!\nBye"
956// }),
957// )
958// .await;
959// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
960
961// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
962// let diagnostic = lsp::Diagnostic {
963// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
964// severity: Some(lsp::DiagnosticSeverity::ERROR),
965// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
966// ..Default::default()
967// };
968
969// project.update(cx, |project, cx| {
970// project.lsp_store().update(cx, |lsp_store, cx| {
971// // Create some diagnostics
972// lsp_store
973// .update_diagnostics(
974// LanguageServerId(0),
975// lsp::PublishDiagnosticsParams {
976// uri: path_to_buffer_uri.clone(),
977// diagnostics: vec![diagnostic],
978// version: None,
979// },
980// None,
981// language::DiagnosticSourceKind::Pushed,
982// &[],
983// cx,
984// )
985// .unwrap();
986// });
987// });
988
989// let buffer = project
990// .update(cx, |project, cx| {
991// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
992// project.open_buffer(path, cx)
993// })
994// .await
995// .unwrap();
996
997// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
998// let position = snapshot.anchor_before(language::Point::new(0, 0));
999
1000// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1001// ep_store.request_prediction(&project, &buffer, position, cx)
1002// });
1003
1004// let (request, _respond_tx) = req_rx.next().await.unwrap();
1005
1006// assert_eq!(request.diagnostic_groups.len(), 1);
1007// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1008// .unwrap();
1009// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1010// assert_eq!(
1011// value,
1012// json!({
1013// "entries": [{
1014// "range": {
1015// "start": 8,
1016// "end": 10
1017// },
1018// "diagnostic": {
1019// "source": null,
1020// "code": null,
1021// "code_description": null,
1022// "severity": 1,
1023// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1024// "markdown": null,
1025// "group_id": 0,
1026// "is_primary": true,
1027// "is_disk_based": false,
1028// "is_unnecessary": false,
1029// "source_kind": "Pushed",
1030// "data": null,
1031// "underline": true
1032// }
1033// }],
1034// "primary_ix": 0
1035// })
1036// );
1037// }
1038
1039fn model_response(text: &str) -> open_ai::Response {
1040 open_ai::Response {
1041 id: Uuid::new_v4().to_string(),
1042 object: "response".into(),
1043 created: 0,
1044 model: "model".into(),
1045 choices: vec![open_ai::Choice {
1046 index: 0,
1047 message: open_ai::RequestMessage::Assistant {
1048 content: Some(open_ai::MessageContent::Plain(text.to_string())),
1049 tool_calls: vec![],
1050 },
1051 finish_reason: None,
1052 }],
1053 usage: Usage {
1054 prompt_tokens: 0,
1055 completion_tokens: 0,
1056 total_tokens: 0,
1057 },
1058 }
1059}
1060
1061fn prompt_from_request(request: &open_ai::Request) -> &str {
1062 assert_eq!(request.messages.len(), 1);
1063 let open_ai::RequestMessage::User {
1064 content: open_ai::MessageContent::Plain(content),
1065 ..
1066 } = &request.messages[0]
1067 else {
1068 panic!(
1069 "Request does not have single user message of type Plain. {:#?}",
1070 request
1071 );
1072 };
1073 content
1074}
1075
/// Receiving ends for the fake HTTP endpoints installed by
/// `init_test_with_fake_client`. Each item pairs a deserialized request
/// body with a oneshot sender the test uses to deliver the response.
struct RequestChannels {
    // Requests hitting `/predict_edits/raw`; answer with an `open_ai::Response`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Requests hitting `/predict_edits/reject`; answer with `()` to acknowledge.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1080
/// Test harness setup: installs test settings, a fake HTTP client, and a
/// fake authenticated cloud client, then returns the global
/// `EditPredictionStore` plus channels on which tests receive intercepted
/// predict/reject requests and send back responses.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        // Settings and logging must be in place before clients are built.
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        // Routes the three endpoints the store talks to. Token requests are
        // answered inline; predict/reject requests are forwarded to the test
        // via channels, and the handler awaits the test's oneshot response.
        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // Dropping the sender makes this error, which
                            // surfaces as a failed request to the store.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // Pretend we're signed in so LLM token fetching succeeds.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1149
// 0BSD license text. Tests below place this in a worktree's LICENSE file to
// mark the worktree as permissively licensed, which (per the assertions in
// `test_can_collect_data`) makes its files eligible for data collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1151
// Verifies that a prediction's anchored edits "interpolate" as the user keeps
// typing: portions the user has already typed are trimmed from the predicted
// edits, remaining edits shift with the buffer, undo restores the original
// interpolation, and the prediction becomes `None` once the buffer diverges
// from what the prediction expected.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: replace offsets 2..5 ("rem") with "REM", delete 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Hand-built prediction; only `edits`/`snapshot` matter for interpolation.
    let completion = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: EditPredictionInputs {
            events: Default::default(),
            included_files: Default::default(),
            cursor_point: cloud_llm_client::predict_edits_v3::Point {
                line: Line(0),
                column: 0,
            },
            cursor_path: Path::new("").into(),
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Untouched buffer: interpolation returns the edits unchanged.
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // User deletes "rem" (2..5): replacement collapses to an insertion at
        // offset 2; the deletion edit shifts left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the buffer, so the original edits come back.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // User types "R" over "rem": the typed prefix is consumed, leaving
        // "EM" still to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E" consumes another predicted character.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing "M" completes the first edit entirely; only the deletion
        // remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the just-typed "M" brings the insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // User performs the predicted deletion themselves: that edit drops out.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1266
// Checks that the model's raw output (wrapped in editable-region markers) is
// diffed against the buffer and applied cleanly — marker lines and the
// trailing blank line inside the region must not leak into the buffer text.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Case 1: a rename touching two places on one line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Case 2: completing an unfinished string literal plus the missing
    // semicolon at the end of the line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1324
1325#[gpui::test]
1326async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1327 init_test(cx);
1328
1329 let buffer_content = "lorem\n";
1330 let completion_response = indoc! {"
1331 ```animals.js
1332 <|start_of_file|>
1333 <|editable_region_start|>
1334 lorem
1335 ipsum
1336 <|editable_region_end|>
1337 ```"};
1338
1339 assert_eq!(
1340 apply_edit_prediction(buffer_content, completion_response, cx).await,
1341 "lorem\nipsum"
1342 );
1343}
1344
// With a permissive LICENSE in the worktree, `can_collect_data` on the
// prediction request must follow the user's data-collection choice.
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // 0BSD license marks the worktree as open source.
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    // NOTE(review): "src/main.rs" is not present in the inserted tree —
    // presumably `open_local_buffer` yields a new buffer at that path; confirm
    // this is intentional rather than a typo.
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Opted in + licensed worktree => request is marked collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    // Opting out flips the flag even though the worktree is still licensed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1382
1383#[gpui::test]
1384async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1385 init_test(cx);
1386
1387 let fs = project::FakeFs::new(cx.executor());
1388 let project = Project::test(fs.clone(), [], cx).await;
1389
1390 let buffer = cx.new(|_cx| {
1391 Buffer::remote(
1392 language::BufferId::new(1).unwrap(),
1393 ReplicaId::new(1),
1394 language::Capability::ReadWrite,
1395 "fn main() {\n println!(\"Hello\");\n}",
1396 )
1397 });
1398
1399 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1400 ep_store.update(cx, |ep_store, _cx| {
1401 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1402 });
1403
1404 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1405 assert_eq!(
1406 captured_request.lock().clone().unwrap().can_collect_data,
1407 false
1408 );
1409}
1410
1411#[gpui::test]
1412async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1413 init_test(cx);
1414
1415 let fs = project::FakeFs::new(cx.executor());
1416 fs.insert_tree(
1417 path!("/project"),
1418 json!({
1419 "LICENSE": BSD_0_TXT,
1420 ".env": "SECRET_KEY=secret"
1421 }),
1422 )
1423 .await;
1424
1425 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1426 let buffer = project
1427 .update(cx, |project, cx| {
1428 project.open_local_buffer("/project/.env", cx)
1429 })
1430 .await
1431 .unwrap();
1432
1433 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1434 ep_store.update(cx, |ep_store, _cx| {
1435 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1436 });
1437
1438 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1439 assert_eq!(
1440 captured_request.lock().clone().unwrap().can_collect_data,
1441 false
1442 );
1443}
1444
1445#[gpui::test]
1446async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1447 init_test(cx);
1448
1449 let fs = project::FakeFs::new(cx.executor());
1450 let project = Project::test(fs.clone(), [], cx).await;
1451 let buffer = cx.new(|cx| Buffer::local("", cx));
1452
1453 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1454 ep_store.update(cx, |ep_store, _cx| {
1455 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1456 });
1457
1458 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1459 assert_eq!(
1460 captured_request.lock().clone().unwrap().can_collect_data,
1461 false
1462 );
1463}
1464
1465#[gpui::test]
1466async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1467 init_test(cx);
1468
1469 let fs = project::FakeFs::new(cx.executor());
1470 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1471 .await;
1472
1473 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1474 let buffer = project
1475 .update(cx, |project, cx| {
1476 project.open_local_buffer("/project/main.rs", cx)
1477 })
1478 .await
1479 .unwrap();
1480
1481 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1482 ep_store.update(cx, |ep_store, _cx| {
1483 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1484 });
1485
1486 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1487 assert_eq!(
1488 captured_request.lock().clone().unwrap().can_collect_data,
1489 false
1490 );
1491}
1492
// When a buffer's file moves from a licensed worktree to an unlicensed one,
// subsequent prediction requests must stop being marked collectable.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // Worktree with a permissive license ...
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    // ... and one without any license.
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While the file lives in the licensed worktree, collection is allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Load a File handle from the unlicensed worktree to simulate the buffer
    // being moved there.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    // Reassign the buffer's file, as a rename/move across worktrees would.
    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    // The same buffer is no longer collectable after the move.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1556
1557#[gpui::test]
1558async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1559 init_test(cx);
1560
1561 let fs = project::FakeFs::new(cx.executor());
1562 fs.insert_tree(
1563 path!("/worktree1"),
1564 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1565 )
1566 .await;
1567 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1568 .await;
1569
1570 let project = Project::test(
1571 fs.clone(),
1572 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1573 cx,
1574 )
1575 .await;
1576 let buffer = project
1577 .update(cx, |project, cx| {
1578 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1579 })
1580 .await
1581 .unwrap();
1582 let private_buffer = project
1583 .update(cx, |project, cx| {
1584 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1585 })
1586 .await
1587 .unwrap();
1588
1589 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1590 ep_store.update(cx, |ep_store, _cx| {
1591 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1592 });
1593
1594 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1595 assert_eq!(
1596 captured_request.lock().clone().unwrap().can_collect_data,
1597 true
1598 );
1599
1600 // this has a side effect of registering the buffer to watch for edits
1601 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1602 assert_eq!(
1603 captured_request.lock().clone().unwrap().can_collect_data,
1604 false
1605 );
1606
1607 private_buffer.update(cx, |private_buffer, cx| {
1608 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1609 });
1610
1611 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1612 assert_eq!(
1613 captured_request.lock().clone().unwrap().can_collect_data,
1614 false
1615 );
1616
1617 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1618 // included
1619 buffer.update(cx, |buffer, cx| {
1620 buffer.edit(
1621 [(
1622 0..0,
1623 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1624 )],
1625 None,
1626 cx,
1627 );
1628 });
1629
1630 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1631 assert_eq!(
1632 captured_request.lock().clone().unwrap().can_collect_data,
1633 true
1634 );
1635}
1636
1637fn init_test(cx: &mut TestAppContext) {
1638 cx.update(|cx| {
1639 let settings_store = SettingsStore::test(cx);
1640 cx.set_global(settings_store);
1641 });
1642}
1643
1644async fn apply_edit_prediction(
1645 buffer_content: &str,
1646 completion_response: &str,
1647 cx: &mut TestAppContext,
1648) -> String {
1649 let fs = project::FakeFs::new(cx.executor());
1650 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1651 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1652 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1653 *response.lock() = completion_response.to_string();
1654 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1655 buffer.update(cx, |buffer, cx| {
1656 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1657 });
1658 buffer.read_with(cx, |buffer, _| buffer.text())
1659}
1660
1661async fn run_edit_prediction(
1662 buffer: &Entity<Buffer>,
1663 project: &Entity<Project>,
1664 ep_store: &Entity<EditPredictionStore>,
1665 cx: &mut TestAppContext,
1666) -> EditPrediction {
1667 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1668 ep_store.update(cx, |ep_store, cx| {
1669 ep_store.register_buffer(buffer, &project, cx)
1670 });
1671 cx.background_executor.run_until_parked();
1672 let prediction_task = ep_store.update(cx, |ep_store, cx| {
1673 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1674 });
1675 prediction_task.await.unwrap().unwrap().prediction.unwrap()
1676}
1677
/// Builds an `EditPredictionStore` wired to a fake HTTP server.
///
/// Returns the store plus two shared handles for the tests to use:
/// - the last `PredictEditsBody` the fake server received, and
/// - the output excerpt the fake server will respond with (pre-seeded with a
///   trivial default; overwrite it to control the next prediction).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    let default_response = indoc! {"
        ```main.rs
        <|start_of_file|>
        <|editable_region_start|>
        hello world
        <|editable_region_end|>
        ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh endpoint: always hand out a fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction endpoint: record the request body and reply
                    // with the (mutable) canned output excerpt.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    // Anything else is unexpected in these tests.
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Set up license detection for every worktree so the data-collection
        // tests above can observe licensed vs. unlicensed behavior.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
1768
1769fn to_completion_edits(
1770 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1771 buffer: &Entity<Buffer>,
1772 cx: &App,
1773) -> Vec<(Range<Anchor>, Arc<str>)> {
1774 let buffer = buffer.read(cx);
1775 iterator
1776 .into_iter()
1777 .map(|(range, text)| {
1778 (
1779 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1780 text,
1781 )
1782 })
1783 .collect()
1784}
1785
1786fn from_completion_edits(
1787 editor_edits: &[(Range<Anchor>, Arc<str>)],
1788 buffer: &Entity<Buffer>,
1789 cx: &App,
1790) -> Vec<(Range<usize>, Arc<str>)> {
1791 let buffer = buffer.read(cx);
1792 editor_edits
1793 .iter()
1794 .map(|(range, text)| {
1795 (
1796 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1797 text.clone(),
1798 )
1799 })
1800 .collect()
1801}
1802
// Runs at binary startup (before any test executes) so every test in this
// module gets zlog output without each one calling `zlog::init_test`.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}