1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{
10 RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
11 },
12};
13use futures::{
14 AsyncReadExt, StreamExt,
15 channel::{mpsc, oneshot},
16};
17use gpui::{
18 Entity, TestAppContext,
19 http_client::{FakeHttpClient, Response},
20};
21use indoc::indoc;
22use language::Point;
23use lsp::LanguageServerId;
24use parking_lot::Mutex;
25use pretty_assertions::{assert_eq, assert_matches};
26use project::{FakeFs, Project};
27use serde_json::json;
28use settings::SettingsStore;
29use std::{path::Path, sync::Arc, time::Duration};
30use util::{path, rel_path::rel_path};
31use uuid::Uuid;
32use zeta_prompt::ZetaPromptInput;
33
34use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
35
// End-to-end check of the store's "current prediction" state machine: a
// prediction for the active buffer surfaces as `Local`; after rejecting it, a
// diagnostic in another file yields a `Jump` prediction, which reads as
// `Local` once that file's buffer is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Make 1.txt the active file before opening its buffer.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Respond with a diff that edits the active file itself.
    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // An edit in the buffer under the cursor is presented as a local prediction.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers another predict request
    // (awaited below) without an explicit refresh call.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                Hola!
                -Como
                +Como estas?
                Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's point of view, the prediction targets another file, so it
    // is presented as a jump to root/2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Once 2.txt itself is open, the same prediction is local to that buffer.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
171
// Happy-path request/response round trip: `request_prediction` sends a predict
// request, the fake model answers with a unified diff, and the resulting
// prediction contains the expected single edit at the expected position.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff "How" -> "How are you?" is interpolated as one insertion of
    // " are you?" at row 1, column 3 (right after "How").
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
238
// Verifies that edits made to a registered buffer are recorded as events and
// included (as a unified diff) in the prompt of a subsequent predict request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register before editing so the store observes the change below.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" into the empty second line (offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The recorded edit must appear in the prompt as a diff of foo.md.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
311
// Compares the two edit-history getters: the plain getter coalesces edits
// separated by a pause into one event, while the pause-splitting getter breaks
// the last event at the pause boundary into two separate diffs.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How are you?!
            Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First diff covers only the pre-pause burst...
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}
    );

    // ...and the second diff covers everything after the pause.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -How
            +How are you?!
            Bye
        "}
    );
}
406
// Exercises line-span-based coalescing of edit-history events: edits within a
// few lines of an existing event's range (above or below it) merge into that
// event, while an edit far away starts a new event.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Still a single event: the insertion merged into the existing diff.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now; render_events joins them with a "---" separator line.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17

            ---
            @@ -23,6 +23,7 @@
            Line 22
            Line 23
            Line 24
            +Far below
            Line 25
            Line 26
            Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
585
586fn render_events(events: &[StoredEvent]) -> String {
587 events
588 .iter()
589 .map(|e| {
590 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
591 diff.as_str()
592 })
593 .collect::<Vec<_>>()
594 .join("\n---\n")
595}
596
// An empty model response yields no prediction, and the request is
// automatically reported as rejected with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Respond with an empty diff; keep the response id to match the rejection.
    let response = model_response(request, "");
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
651
// If the buffer is edited to already contain the model's suggestion before the
// response arrives, interpolation leaves nothing to apply: no prediction is
// surfaced and the request is rejected with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is in flight, make the buffer already match the
    // suggested text from SIMPLE_DIFF.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(request, SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
711
// Canonical model response shared by several tests below: a unified diff for
// root/foo.md that rewrites the line "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
721
// A newer prediction of equal quality replaces the current one, and the
// replaced prediction is reported as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
803
// When a newer prediction is worse than the current one, the current prediction
// is kept and the newer one is rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction: suggests only "How are" instead of the
    // current "How are you?"
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
897
// If a later request resolves before an earlier in-flight one, the earlier
// request is cancelled: its eventual response is ignored and it is reported as
// rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the earlier (already-cancelled) request respond.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
988
// With two requests already pending, starting a third cancels the second (the
// most recent pending one). The first can still land, the cancelled second's
// response is ignored, and the third ultimately replaces the first. Rejections
// are reported for both the cancelled and the replaced predictions.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in cancelled_predictions)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request responds; its prediction must be ignored.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1124
// Covers batching behavior of rejection reporting: rejections are debounced
// and sent in one batch, a full batch flushes immediately without waiting for
// the debounce, and items from a failed request are retried with the next
// batch.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    // Nothing is sent until the debounce elapses.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure by dropping the responder instead of acknowledging.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1236
1237// Skipped until we start including diagnostics in prompt
1238// #[gpui::test]
1239// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1240// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1241// let fs = FakeFs::new(cx.executor());
1242// fs.insert_tree(
1243// "/root",
1244// json!({
1245// "foo.md": "Hello!\nBye"
1246// }),
1247// )
1248// .await;
1249// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1250
1251// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1252// let diagnostic = lsp::Diagnostic {
1253// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1254// severity: Some(lsp::DiagnosticSeverity::ERROR),
1255// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1256// ..Default::default()
1257// };
1258
1259// project.update(cx, |project, cx| {
1260// project.lsp_store().update(cx, |lsp_store, cx| {
1261// // Create some diagnostics
1262// lsp_store
1263// .update_diagnostics(
1264// LanguageServerId(0),
1265// lsp::PublishDiagnosticsParams {
1266// uri: path_to_buffer_uri.clone(),
1267// diagnostics: vec![diagnostic],
1268// version: None,
1269// },
1270// None,
1271// language::DiagnosticSourceKind::Pushed,
1272// &[],
1273// cx,
1274// )
1275// .unwrap();
1276// });
1277// });
1278
1279// let buffer = project
1280// .update(cx, |project, cx| {
1281// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1282// project.open_buffer(path, cx)
1283// })
1284// .await
1285// .unwrap();
1286
1287// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1288// let position = snapshot.anchor_before(language::Point::new(0, 0));
1289
1290// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1291// ep_store.request_prediction(&project, &buffer, position, cx)
1292// });
1293
1294// let (request, _respond_tx) = req_rx.next().await.unwrap();
1295
1296// assert_eq!(request.diagnostic_groups.len(), 1);
1297// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1298// .unwrap();
1299// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1300// assert_eq!(
1301// value,
1302// json!({
1303// "entries": [{
1304// "range": {
1305// "start": 8,
1306// "end": 10
1307// },
1308// "diagnostic": {
1309// "source": null,
1310// "code": null,
1311// "code_description": null,
1312// "severity": 1,
1313// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1314// "markdown": null,
1315// "group_id": 0,
1316// "is_primary": true,
1317// "is_disk_based": false,
1318// "is_unnecessary": false,
1319// "source_kind": "Pushed",
1320// "data": null,
1321// "underline": true
1322// }
1323// }],
1324// "primary_ix": 0
1325// })
1326// );
1327// }
1328
1329// Generate a model response that would apply the given diff to the active file.
1330fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
1331 let prompt = &request.prompt;
1332
1333 let current_marker = "<|fim_middle|>current\n";
1334 let updated_marker = "<|fim_middle|>updated\n";
1335 let cursor = "<|user_cursor|>";
1336
1337 let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
1338 let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
1339 let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1340 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1341
1342 RawCompletionResponse {
1343 id: Uuid::new_v4().to_string(),
1344 object: "text_completion".into(),
1345 created: 0,
1346 model: "model".into(),
1347 choices: vec![RawCompletionChoice {
1348 text: new_excerpt,
1349 finish_reason: None,
1350 }],
1351 usage: RawCompletionUsage {
1352 prompt_tokens: 0,
1353 completion_tokens: 0,
1354 total_tokens: 0,
1355 },
1356 }
1357}
1358
1359fn prompt_from_request(request: &RawCompletionRequest) -> &str {
1360 &request.prompt
1361}
1362
/// Receivers for requests captured by the fake HTTP client built in
/// `init_test_with_fake_client`. Each captured request is paired with a
/// oneshot sender the test uses to deliver the response; dropping the sender
/// simulates a failed request.
struct RequestChannels {
    // Requests to `/predict_edits/raw`, with a channel for the response body.
    predict:
        mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
    // Batched rejection reports sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1368
/// Sets up an `EditPredictionStore` backed by a fake HTTP client.
///
/// LLM-token requests are answered immediately with a static fake token.
/// Prediction and rejection requests are forwarded through the returned
/// `RequestChannels`, where the test inspects them and replies via the paired
/// oneshot sender (or drops it to simulate a server failure).
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: respond with a static fake token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Forward the parsed prediction request to the test
                        // and block until it supplies a response.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Same for rejection batches; `res_rx.await?` fails
                        // (simulating a request error) if the test drops the
                        // sender.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1437
// 0BSD license text; inserting it as a worktree's LICENSE marks the project
// as open source in the data-collection tests below.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1439
// Verifies that a prediction's edits are re-anchored ("interpolated") as the
// buffer changes: edits the user has already typed are dropped, the remaining
// edits shift with the text, and interpolation returns `None` once the buffer
// diverges from the prediction.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: replace offsets 2..5 ("rem") with "REM" and delete 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation reproduces the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the first edit into a pure
        // insertion; the second edit's offsets shift left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the replacement ("R") shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing the next character ("E") consumes more of the first edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the final "M" fully consumes the first edit, leaving only
        // the deletion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the just-typed "M" re-introduces the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Applying the predicted deletion leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates interpolation.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1553
// Verifies that artifacts in the model output (e.g. a spurious trailing blank
// line inside the editable region) are cleaned up and not applied to the
// buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // The model renames `word` to `word_1` but also emits a stray blank line
    // before the end marker; the applied result must not contain it.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Same cleanup when the model extends a string literal.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1611
// A completion that extends past the buffer's final newline should still
// apply cleanly at the end of the buffer.
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
        ```animals.js
        <|start_of_file|>
        <|editable_region_start|>
        lorem
        ipsum
        <|editable_region_end|>
        ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1631
// `can_collect_data` in the request should follow the user's data-collection
// choice when the project is open source (LICENSE is 0BSD).
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Opted in + open source => collection allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    // Opting out flips the flag even though the project is unchanged.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1669
// Remote (non-local) buffers must never be eligible for data collection,
// regardless of the user's choice.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    // A buffer replicated from another collaborator, not backed by a local file.
    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1697
1698#[gpui::test]
1699async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1700 init_test(cx);
1701
1702 let fs = project::FakeFs::new(cx.executor());
1703 fs.insert_tree(
1704 path!("/project"),
1705 json!({
1706 "LICENSE": BSD_0_TXT,
1707 ".env": "SECRET_KEY=secret"
1708 }),
1709 )
1710 .await;
1711
1712 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1713 let buffer = project
1714 .update(cx, |project, cx| {
1715 project.open_local_buffer("/project/.env", cx)
1716 })
1717 .await
1718 .unwrap();
1719
1720 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1721 ep_store.update(cx, |ep_store, _cx| {
1722 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1723 });
1724
1725 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1726 assert_eq!(
1727 captured_request.lock().clone().unwrap().can_collect_data,
1728 false
1729 );
1730}
1731
// A buffer with no backing file (untitled) can never be collected, since its
// license status cannot be determined.
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1751
1752#[gpui::test]
1753async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1754 init_test(cx);
1755
1756 let fs = project::FakeFs::new(cx.executor());
1757 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1758 .await;
1759
1760 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1761 let buffer = project
1762 .update(cx, |project, cx| {
1763 project.open_local_buffer("/project/main.rs", cx)
1764 })
1765 .await
1766 .unwrap();
1767
1768 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1769 ep_store.update(cx, |ep_store, _cx| {
1770 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1771 });
1772
1773 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1774 assert_eq!(
1775 captured_request.lock().clone().unwrap().can_collect_data,
1776 false
1777 );
1778}
1779
// Collection eligibility follows the buffer's file: when the buffer's file is
// swapped from an open-source worktree to a closed-source one, the next
// request must report `can_collect_data: false`.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Initially in the open-source worktree: collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Simulate the file moving into the closed-source worktree by swapping
    // the buffer's backing file.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1843
1844#[gpui::test]
1845async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1846 init_test(cx);
1847
1848 let fs = project::FakeFs::new(cx.executor());
1849 fs.insert_tree(
1850 path!("/worktree1"),
1851 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1852 )
1853 .await;
1854 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1855 .await;
1856
1857 let project = Project::test(
1858 fs.clone(),
1859 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1860 cx,
1861 )
1862 .await;
1863 let buffer = project
1864 .update(cx, |project, cx| {
1865 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1866 })
1867 .await
1868 .unwrap();
1869 let private_buffer = project
1870 .update(cx, |project, cx| {
1871 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1872 })
1873 .await
1874 .unwrap();
1875
1876 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1877 ep_store.update(cx, |ep_store, _cx| {
1878 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1879 });
1880
1881 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1882 assert_eq!(
1883 captured_request.lock().clone().unwrap().can_collect_data,
1884 true
1885 );
1886
1887 // this has a side effect of registering the buffer to watch for edits
1888 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1889 assert_eq!(
1890 captured_request.lock().clone().unwrap().can_collect_data,
1891 false
1892 );
1893
1894 private_buffer.update(cx, |private_buffer, cx| {
1895 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1896 });
1897
1898 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1899 assert_eq!(
1900 captured_request.lock().clone().unwrap().can_collect_data,
1901 false
1902 );
1903
1904 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1905 // included
1906 buffer.update(cx, |buffer, cx| {
1907 buffer.edit(
1908 [(
1909 0..0,
1910 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1911 )],
1912 None,
1913 cx,
1914 );
1915 });
1916
1917 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1918 assert_eq!(
1919 captured_request.lock().clone().unwrap().can_collect_data,
1920 true
1921 );
1922}
1923
1924fn init_test(cx: &mut TestAppContext) {
1925 cx.update(|cx| {
1926 let settings_store = SettingsStore::test(cx);
1927 cx.set_global(settings_store);
1928 });
1929}
1930
/// Runs one full prediction round-trip: creates a buffer with
/// `buffer_content`, makes the fake server return `completion_response`,
/// applies the resulting prediction's edits, and returns the buffer's final
/// text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
    // Override the store's canned completion with the caller's response.
    *response.lock() = completion_response.to_string();
    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
1947
/// Registers `buffer` with the store and requests a prediction with the
/// cursor anchored at the start of the second line. Panics if the request
/// fails or yields no prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects (e.g. event subscriptions) settle.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
1964
/// Builds an `EditPredictionStore` (Zeta1 model) wired to a fake HTTP server.
///
/// Returns the store, a slot holding the most recently captured
/// `/predict_edits/v2` request body, and a mutable slot containing the
/// completion text the fake server returns (defaults to a trivial
/// "hello world" excerpt). License-detection watchers are initialized for
/// every worktree so data-collection eligibility can be exercised.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    let default_response = indoc! {"
        ```main.rs
        <|start_of_file|>
        <|editable_region_start|>
        hello world
        <|editable_region_end|>
        ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh: answer with a static fake token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Capture the request body for test assertions and reply
                    // with whatever is currently in `completion_response`.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Pre-seed license watchers so open-source detection works per worktree.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2055
2056fn to_completion_edits(
2057 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2058 buffer: &Entity<Buffer>,
2059 cx: &App,
2060) -> Vec<(Range<Anchor>, Arc<str>)> {
2061 let buffer = buffer.read(cx);
2062 iterator
2063 .into_iter()
2064 .map(|(range, text)| {
2065 (
2066 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2067 text,
2068 )
2069 })
2070 .collect()
2071}
2072
2073fn from_completion_edits(
2074 editor_edits: &[(Range<Anchor>, Arc<str>)],
2075 buffer: &Entity<Buffer>,
2076 cx: &App,
2077) -> Vec<(Range<usize>, Arc<str>)> {
2078 let buffer = buffer.read(cx);
2079 editor_edits
2080 .iter()
2081 .map(|(range, text)| {
2082 (
2083 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2084 text.clone(),
2085 )
2086 })
2087 .collect()
2088}
2089
2090#[gpui::test]
2091async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2092 init_test(cx);
2093
2094 let fs = FakeFs::new(cx.executor());
2095 fs.insert_tree(
2096 "/project",
2097 serde_json::json!({
2098 "main.rs": "fn main() {\n \n}\n"
2099 }),
2100 )
2101 .await;
2102
2103 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2104
2105 let http_client = FakeHttpClient::create(|_req| async move {
2106 Ok(gpui::http_client::Response::builder()
2107 .status(401)
2108 .body("Unauthorized".into())
2109 .unwrap())
2110 });
2111
2112 let client =
2113 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2114 cx.update(|cx| {
2115 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2116 });
2117
2118 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2119
2120 let buffer = project
2121 .update(cx, |project, cx| {
2122 let path = project
2123 .find_project_path(path!("/project/main.rs"), cx)
2124 .unwrap();
2125 project.open_buffer(path, cx)
2126 })
2127 .await
2128 .unwrap();
2129
2130 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2131 ep_store.update(cx, |ep_store, cx| {
2132 ep_store.register_buffer(&buffer, &project, cx)
2133 });
2134 cx.background_executor.run_until_parked();
2135
2136 let completion_task = ep_store.update(cx, |ep_store, cx| {
2137 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2138 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2139 });
2140
2141 let result = completion_task.await;
2142 assert!(
2143 result.is_err(),
2144 "Without authentication and without custom URL, prediction should fail"
2145 );
2146}
2147
2148#[gpui::test]
2149async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2150 init_test(cx);
2151
2152 let fs = FakeFs::new(cx.executor());
2153 fs.insert_tree(
2154 "/project",
2155 serde_json::json!({
2156 "main.rs": "fn main() {\n \n}\n"
2157 }),
2158 )
2159 .await;
2160
2161 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2162
2163 let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2164 let predict_called_clone = predict_called.clone();
2165
2166 let http_client = FakeHttpClient::create({
2167 move |req| {
2168 let uri = req.uri().path().to_string();
2169 let predict_called = predict_called_clone.clone();
2170 async move {
2171 if uri.contains("predict") {
2172 predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2173 Ok(gpui::http_client::Response::builder()
2174 .body(
2175 serde_json::to_string(&open_ai::Response {
2176 id: "test-123".to_string(),
2177 object: "chat.completion".to_string(),
2178 created: 0,
2179 model: "test".to_string(),
2180 usage: open_ai::Usage {
2181 prompt_tokens: 0,
2182 completion_tokens: 0,
2183 total_tokens: 0,
2184 },
2185 choices: vec![open_ai::Choice {
2186 index: 0,
2187 message: open_ai::RequestMessage::Assistant {
2188 content: Some(open_ai::MessageContent::Plain(
2189 indoc! {"
2190 ```main.rs
2191 <|start_of_file|>
2192 <|editable_region_start|>
2193 fn main() {
2194 println!(\"Hello, world!\");
2195 }
2196 <|editable_region_end|>
2197 ```
2198 "}
2199 .to_string(),
2200 )),
2201 tool_calls: vec![],
2202 },
2203 finish_reason: Some("stop".to_string()),
2204 }],
2205 })
2206 .unwrap()
2207 .into(),
2208 )
2209 .unwrap())
2210 } else {
2211 Ok(gpui::http_client::Response::builder()
2212 .status(401)
2213 .body("Unauthorized".into())
2214 .unwrap())
2215 }
2216 }
2217 }
2218 });
2219
2220 let client =
2221 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2222 cx.update(|cx| {
2223 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2224 });
2225
2226 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2227
2228 let buffer = project
2229 .update(cx, |project, cx| {
2230 let path = project
2231 .find_project_path(path!("/project/main.rs"), cx)
2232 .unwrap();
2233 project.open_buffer(path, cx)
2234 })
2235 .await
2236 .unwrap();
2237
2238 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2239 ep_store.update(cx, |ep_store, cx| {
2240 ep_store.register_buffer(&buffer, &project, cx)
2241 });
2242 cx.background_executor.run_until_parked();
2243
2244 let completion_task = ep_store.update(cx, |ep_store, cx| {
2245 ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2246 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2247 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2248 });
2249
2250 let _ = completion_task.await;
2251
2252 assert!(
2253 predict_called.load(std::sync::atomic::Ordering::SeqCst),
2254 "With custom URL, predict endpoint should be called even without authentication"
2255 );
2256}
2257
// Verifies that `compute_diff_between_snapshots` produces a single unified
// hunk covering both insertions, with the expected context lines.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
                "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insert at the later position first so the earlier edit's coordinates
    // are unaffected by the first insertion.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Both insertions are close enough to merge into one hunk.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2325
// Runs once at test-binary load time (via ctor) to initialize test logging
// for every test in this module.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}