1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9};
10use futures::{
11 AsyncReadExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::{
15 Entity, TestAppContext,
16 http_client::{FakeHttpClient, Response},
17};
18use indoc::indoc;
19use language::{Point, ToOffset as _};
20use lsp::LanguageServerId;
21use open_ai::Usage;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::{path, rel_path::rel_path};
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
// Exercises `prediction_at` state transitions: a local prediction for the
// active buffer, then a "jump" prediction pointing at a different file once a
// diagnostic appears there, and finally a local prediction once that other
// file's buffer is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and make it the active path so predictions target it.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The edit targets the buffer the cursor is in, so it surfaces as Local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction before triggering the next one.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request
    // even though 2.txt is not the active buffer.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from buffer1, the prediction for 2.txt appears as a Jump
    // pointing at the other file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from buffer2 itself, the same prediction is Local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
// Happy path for `request_prediction`: a single request/response round trip
// whose returned edit is anchored at the cursor position.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff rewrites "How" -> "How are you?", which interpolates to a
    // single insertion of " are you?" at the cursor.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
236
// Verifies that edits made to a registered buffer are included as a unified
// diff in the prompt of the next prediction request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registering the buffer makes the store track its edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" on the empty second line (offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The recorded edit shows up in the prompt as a diff of the buffer change.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
309
// Verifies the two edit-history getters: the plain getter coalesces all edits
// on the same line span into one event, while the pause-splitting getter
// splits the last event at a pause longer than the grouping threshold.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    // `Event` currently has the single `BufferChange` variant, so this
    // destructuring is irrefutable.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: everything before the pause.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second event: both post-pause edits, grouped together.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
404
// A model response containing no edits should produce no prediction and be
// auto-reported as rejected with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Empty diff: the model proposes no changes.
    let response = model_response(request, "");
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
459
// If the user types the predicted text while the request is in flight, the
// prediction interpolates to nothing and is rejected as `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is pending, apply the same change the model is about
    // to suggest, so the incoming prediction has nothing left to add.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(request, SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
519
/// Model diff shared by several tests: rewrites line 2 of `root/foo.md` from
/// "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
529
// A newly completed prediction replaces the current one, and the replaced
// prediction is reported as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction becomes current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
611
// If the incoming prediction is worse than the current one, the current one
// is kept and the new one is rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction becomes current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
705
// When a later request completes before an earlier pending one, the earlier
// request is cancelled: its late response is ignored and it is reported as
// rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now deliver the first (stale) response.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
796
// With two requests already pending, starting a third cancels the second
// (the most recent pending one is kept plus the new one). The cancelled
// response is ignored; the third eventually replaces the first.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Deliver the cancelled (second) response; it must not become current.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both outcomes are batched into one rejection report: the cancelled
    // second request and the first prediction replaced by the third.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
932
// Covers how rejection reports are flushed: debounced batching, immediate
// send when the batch-size limit is hit, and retry of failed sends together
// with newly queued rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    // Nothing is sent until the debounce interval elapses.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1044
1045// Skipped until we start including diagnostics in prompt
1046// #[gpui::test]
1047// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1048// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1049// let fs = FakeFs::new(cx.executor());
1050// fs.insert_tree(
1051// "/root",
1052// json!({
1053// "foo.md": "Hello!\nBye"
1054// }),
1055// )
1056// .await;
1057// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1058
1059// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1060// let diagnostic = lsp::Diagnostic {
1061// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1062// severity: Some(lsp::DiagnosticSeverity::ERROR),
1063// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1064// ..Default::default()
1065// };
1066
1067// project.update(cx, |project, cx| {
1068// project.lsp_store().update(cx, |lsp_store, cx| {
1069// // Create some diagnostics
1070// lsp_store
1071// .update_diagnostics(
1072// LanguageServerId(0),
1073// lsp::PublishDiagnosticsParams {
1074// uri: path_to_buffer_uri.clone(),
1075// diagnostics: vec![diagnostic],
1076// version: None,
1077// },
1078// None,
1079// language::DiagnosticSourceKind::Pushed,
1080// &[],
1081// cx,
1082// )
1083// .unwrap();
1084// });
1085// });
1086
1087// let buffer = project
1088// .update(cx, |project, cx| {
1089// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1090// project.open_buffer(path, cx)
1091// })
1092// .await
1093// .unwrap();
1094
1095// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1096// let position = snapshot.anchor_before(language::Point::new(0, 0));
1097
1098// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1099// ep_store.request_prediction(&project, &buffer, position, cx)
1100// });
1101
1102// let (request, _respond_tx) = req_rx.next().await.unwrap();
1103
1104// assert_eq!(request.diagnostic_groups.len(), 1);
1105// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1106// .unwrap();
1107// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1108// assert_eq!(
1109// value,
1110// json!({
1111// "entries": [{
1112// "range": {
1113// "start": 8,
1114// "end": 10
1115// },
1116// "diagnostic": {
1117// "source": null,
1118// "code": null,
1119// "code_description": null,
1120// "severity": 1,
1121// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1122// "markdown": null,
1123// "group_id": 0,
1124// "is_primary": true,
1125// "is_disk_based": false,
1126// "is_unnecessary": false,
1127// "source_kind": "Pushed",
1128// "data": null,
1129// "underline": true
1130// }
1131// }],
1132// "primary_ix": 0
1133// })
1134// );
1135// }
1136
1137// Generate a model response that would apply the given diff to the active file.
1138fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1139 let prompt = match &request.messages[0] {
1140 open_ai::RequestMessage::User {
1141 content: open_ai::MessageContent::Plain(content),
1142 } => content,
1143 _ => panic!("unexpected request {request:?}"),
1144 };
1145
1146 let open = "<editable_region>\n";
1147 let close = "</editable_region>";
1148 let cursor = "<|user_cursor|>";
1149
1150 let start_ix = open.len() + prompt.find(open).unwrap();
1151 let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1152 let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1153 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1154
1155 open_ai::Response {
1156 id: Uuid::new_v4().to_string(),
1157 object: "response".into(),
1158 created: 0,
1159 model: "model".into(),
1160 choices: vec![open_ai::Choice {
1161 index: 0,
1162 message: open_ai::RequestMessage::Assistant {
1163 content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1164 tool_calls: vec![],
1165 },
1166 finish_reason: None,
1167 }],
1168 usage: Usage {
1169 prompt_tokens: 0,
1170 completion_tokens: 0,
1171 total_tokens: 0,
1172 },
1173 }
1174}
1175
1176fn prompt_from_request(request: &open_ai::Request) -> &str {
1177 assert_eq!(request.messages.len(), 1);
1178 let open_ai::RequestMessage::User {
1179 content: open_ai::MessageContent::Plain(content),
1180 ..
1181 } = &request.messages[0]
1182 else {
1183 panic!(
1184 "Request does not have single user message of type Plain. {:#?}",
1185 request
1186 );
1187 };
1188 content
1189}
1190
/// Receivers carrying requests captured by the fake HTTP client, each paired
/// with a oneshot sender the test uses to reply.
struct RequestChannels {
    // Requests to "/predict_edits/raw".
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Requests to "/predict_edits/reject".
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1195
/// Sets up test globals and an `EditPredictionStore` backed by a fake HTTP
/// client. Predict/reject requests are forwarded over the returned
/// `RequestChannels` so tests can inspect them and answer on demand.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        // Each captured request travels with a oneshot sender the test uses to reply.
        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: reply immediately with a fixed token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction request: forward the parsed body to the
                        // test and block until it sends a response back.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection report: same forwarding scheme as above.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1264
// Permissive (0BSD) license text; tests place it as a worktree's LICENSE to
// make that worktree eligible for data collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1266
/// Verifies that a prediction's edits are re-anchored ("interpolated") as the
/// buffer changes: user typing that matches a pending edit shrinks it, and
/// edits that conflict invalidate the prediction entirely. Each step below
/// mutates the buffer, so order matters.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Prompt-input fields are irrelevant to interpolation; use empty defaults.
    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: edits round-trip unchanged.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the to-be-replaced range turns it into a pure insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the replacement consumes it character by character.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Fully typed replacement disappears from the remaining edits.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting a just-typed character brings the edit back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the deletion edit manually removes it from the prediction.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // A conflicting edit invalidates the prediction entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1380
/// Verifies that model output containing the editable-region markers is
/// cleaned up and applied correctly: markers and the trailing blank line are
/// stripped, and only the real content changes reach the buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Rename-style edit: `word` -> `word_1` on a single line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Completion-style edit: extend a string literal and add the semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1438
/// Verifies a prediction that appends content at the end of the buffer is
/// applied correctly (including dropping the buffer's trailing newline when
/// the model output has none).
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
        ```animals.js
        <|start_of_file|>
        <|editable_region_start|>
        lorem
        ipsum
        <|editable_region_end|>
        ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1458
/// Verifies `can_collect_data` follows the user's data-collection choice when
/// the worktree is open source (0BSD LICENSE present).
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    // NOTE(review): src/main.rs does not exist in the fake fs tree above;
    // presumably open_local_buffer yields a buffer for the not-yet-saved path —
    // confirm that's the intent.
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Enabled + open-source worktree => collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    // Disabling the choice flips the flag even for the same buffer.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1496
/// Verifies data collection is never enabled for remote buffers, even when
/// the user has opted in.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    // A remote buffer (non-zero replica id) rather than a local file.
    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1524
1525#[gpui::test]
1526async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1527 init_test(cx);
1528
1529 let fs = project::FakeFs::new(cx.executor());
1530 fs.insert_tree(
1531 path!("/project"),
1532 json!({
1533 "LICENSE": BSD_0_TXT,
1534 ".env": "SECRET_KEY=secret"
1535 }),
1536 )
1537 .await;
1538
1539 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1540 let buffer = project
1541 .update(cx, |project, cx| {
1542 project.open_local_buffer("/project/.env", cx)
1543 })
1544 .await
1545 .unwrap();
1546
1547 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1548 ep_store.update(cx, |ep_store, _cx| {
1549 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1550 });
1551
1552 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1553 assert_eq!(
1554 captured_request.lock().clone().unwrap().can_collect_data,
1555 false
1556 );
1557}
1558
/// Verifies data collection is never enabled for untitled buffers (no
/// backing file), even with collection opted in.
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1578
1579#[gpui::test]
1580async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1581 init_test(cx);
1582
1583 let fs = project::FakeFs::new(cx.executor());
1584 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1585 .await;
1586
1587 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1588 let buffer = project
1589 .update(cx, |project, cx| {
1590 project.open_local_buffer("/project/main.rs", cx)
1591 })
1592 .await
1593 .unwrap();
1594
1595 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1596 ep_store.update(cx, |ep_store, _cx| {
1597 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1598 });
1599
1600 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1601 assert_eq!(
1602 captured_request.lock().clone().unwrap().can_collect_data,
1603 false
1604 );
1605}
1606
/// Verifies collectability is re-evaluated when a buffer's file moves from an
/// open-source worktree to a closed-source one.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While in the licensed worktree the buffer is collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Simulate a move by swapping in a file handle from the unlicensed worktree.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1670
/// Verifies that edit-history events originating in uncollectable buffers
/// poison data collection for subsequent predictions, until those events age
/// out of the bounded event window.
#[gpui::test]
async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/worktree1"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
        })
        .await
        .unwrap();
    // NOTE(review): the tree above defines "private.rs" but this opens
    // "file.rs"; presumably open_local_buffer still yields a buffer for the
    // missing path (worktree2 is unlicensed either way) — confirm intent.
    let private_buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Licensed worktree buffer: collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // this has a side effect of registering the buffer to watch for edits
    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    private_buffer.update(cx, |private_buffer, cx| {
        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
    });

    // The uncollectable buffer's edit is now in the event history, so even a
    // prediction in the licensed buffer cannot collect data.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
    // included
    buffer.update(cx, |buffer, cx| {
        buffer.edit(
            [(
                0..0,
                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
            )],
            None,
            cx,
        );
    });

    // With the private event evicted from the window, collection is allowed again.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );
}
1750
/// Minimal shared setup: installs a test `SettingsStore` as a global.
fn init_test(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
    });
}
1757
/// Runs one edit prediction against a fresh buffer containing
/// `buffer_content`, with the fake model returning `completion_response`,
/// applies the predicted edits, and returns the resulting buffer text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
    // Override the store's canned model output for this prediction.
    *response.lock() = completion_response.to_string();
    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
1774
/// Registers `buffer` with the store and requests a single prediction with
/// the cursor at row 1, column 0, returning the unwrapped prediction.
/// Panics if the request fails or yields no prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
1791
1792async fn make_test_ep_store(
1793 project: &Entity<Project>,
1794 cx: &mut TestAppContext,
1795) -> (
1796 Entity<EditPredictionStore>,
1797 Arc<Mutex<Option<PredictEditsBody>>>,
1798 Arc<Mutex<String>>,
1799) {
1800 let default_response = indoc! {"
1801 ```main.rs
1802 <|start_of_file|>
1803 <|editable_region_start|>
1804 hello world
1805 <|editable_region_end|>
1806 ```"
1807 };
1808 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1809 let completion_response: Arc<Mutex<String>> =
1810 Arc::new(Mutex::new(default_response.to_string()));
1811 let http_client = FakeHttpClient::create({
1812 let captured_request = captured_request.clone();
1813 let completion_response = completion_response.clone();
1814 let mut next_request_id = 0;
1815 move |req| {
1816 let captured_request = captured_request.clone();
1817 let completion_response = completion_response.clone();
1818 async move {
1819 match (req.method(), req.uri().path()) {
1820 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1821 .status(200)
1822 .body(
1823 serde_json::to_string(&CreateLlmTokenResponse {
1824 token: LlmToken("the-llm-token".to_string()),
1825 })
1826 .unwrap()
1827 .into(),
1828 )
1829 .unwrap()),
1830 (&Method::POST, "/predict_edits/v2") => {
1831 let mut request_body = String::new();
1832 req.into_body().read_to_string(&mut request_body).await?;
1833 *captured_request.lock() =
1834 Some(serde_json::from_str(&request_body).unwrap());
1835 next_request_id += 1;
1836 Ok(http_client::Response::builder()
1837 .status(200)
1838 .body(
1839 serde_json::to_string(&PredictEditsResponse {
1840 request_id: format!("request-{next_request_id}"),
1841 output_excerpt: completion_response.lock().clone(),
1842 })
1843 .unwrap()
1844 .into(),
1845 )
1846 .unwrap())
1847 }
1848 _ => Ok(http_client::Response::builder()
1849 .status(404)
1850 .body("Not Found".into())
1851 .unwrap()),
1852 }
1853 }
1854 }
1855 });
1856
1857 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1858 cx.update(|cx| {
1859 RefreshLlmTokenListener::register(client.clone(), cx);
1860 });
1861 let _server = FakeServer::for_client(42, &client, cx).await;
1862
1863 let ep_store = cx.new(|cx| {
1864 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
1865 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
1866
1867 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
1868 for worktree in worktrees {
1869 let worktree_id = worktree.read(cx).id();
1870 ep_store
1871 .get_or_init_project(project, cx)
1872 .license_detection_watchers
1873 .entry(worktree_id)
1874 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
1875 }
1876
1877 ep_store
1878 });
1879
1880 (ep_store, captured_request, completion_response)
1881}
1882
1883fn to_completion_edits(
1884 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1885 buffer: &Entity<Buffer>,
1886 cx: &App,
1887) -> Vec<(Range<Anchor>, Arc<str>)> {
1888 let buffer = buffer.read(cx);
1889 iterator
1890 .into_iter()
1891 .map(|(range, text)| {
1892 (
1893 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1894 text,
1895 )
1896 })
1897 .collect()
1898}
1899
1900fn from_completion_edits(
1901 editor_edits: &[(Range<Anchor>, Arc<str>)],
1902 buffer: &Entity<Buffer>,
1903 cx: &App,
1904) -> Vec<(Range<usize>, Arc<str>)> {
1905 let buffer = buffer.read(cx);
1906 editor_edits
1907 .iter()
1908 .map(|(range, text)| {
1909 (
1910 range.start.to_offset(buffer)..range.end.to_offset(buffer),
1911 text.clone(),
1912 )
1913 })
1914 .collect()
1915}
1916
/// Verifies that without authentication (server answers 401) and with no
/// custom predict-edits URL configured, requesting a prediction fails.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request is rejected, simulating a signed-out client.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
1974
/// Verifies that a configured custom predict-edits URL lets predictions reach
/// the predict endpoint even when authentication requests fail (401).
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Records whether the predict endpoint was actually hit.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                // Predict requests succeed with a canned completion; anything
                // else (e.g. token refresh) gets a 401.
                if uri.contains("predict") {
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                            ```main.rs
                                            <|start_of_file|>
                                            <|editable_region_start|>
                                            fn main() {
                                                println!(\"Hello, world!\");
                                            }
                                            <|editable_region_end|>
                                            ```
                                        "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // The result itself is irrelevant; only the endpoint hit matters here.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2084
/// Verifies `compute_diff_between_snapshots` produces a unified diff with two
/// separate insertions merged into a single hunk with surrounding context.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
                "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insert at the later position first so the earlier edit's coordinates
    // are unaffected by the first insertion.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2152
// Runs at binary load time (via ctor) so logging is initialized for all tests.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}