1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9};
10use futures::{
11 AsyncReadExt, StreamExt,
12 channel::{mpsc, oneshot},
13};
14use gpui::{
15 Entity, TestAppContext,
16 http_client::{FakeHttpClient, Response},
17};
18use indoc::indoc;
19use language::Point;
20use lsp::LanguageServerId;
21use open_ai::Usage;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::{path, rel_path::rel_path};
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
// Verifies how `prediction_at` classifies the store's current prediction:
// an edit targeting the queried buffer surfaces as `Local`, while an edit
// targeting another file (here driven by a pushed diagnostic) surfaces as
// `Jump` from other buffers — and as `Local` once that file's buffer is
// opened and queried directly.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor inside line 1 ("How"), where the model's edit lands.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction edits the queried buffer itself, so it is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction before exercising the diagnostic-driven
    // flow, so the next prediction is not compared against it.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // Publishing the diagnostic triggers a new prediction request (no explicit
    // refresh call here); answer it with an edit to 2.txt.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction targets another file: `Jump`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
169
// Single round-trip smoke test: issue `request_prediction`, answer with a
// one-hunk diff, and verify the resulting edit is anchored at the cursor
// position with exactly the inserted text.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff rewrites "How" -> "How are you?", which interpolates to a
    // single insertion of " are you?" at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
236
// Verifies that buffer edits made before a prediction request are recorded
// as edit-history events and included (as a unified diff) in the prompt sent
// to the model.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registering the buffer is what makes the store observe its edits.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Offset 7 is the start of the empty line 1: inserting "How" turns the
    // buffer into "Hello!\nHow\nBye\n".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt should contain the recorded edit as a unified diff.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
309
// Verifies the two edit-history getters against the same edit sequence:
// `edit_history_for_project` coalesces same-line bursts into one event, while
// `edit_history_for_project_with_pause_split_last_event` splits the last
// event at a pause longer than the grouping threshold, yielding two diffs.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: everything up to the pause boundary.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second event: the post-pause burst, diffed against the split point.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
404
// Verifies line-span coalescing of edit-history events: edits within 8 lines
// of the existing event's range (above or below) merge into it; edits farther
// away start a new event.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: `render_events` joins their diffs with "---".
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
583
584fn render_events(events: &[StoredEvent]) -> String {
585 events
586 .iter()
587 .map(|e| {
588 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
589 diff.as_str()
590 })
591 .collect::<Vec<_>>()
592 .join("\n---\n")
593}
594
// A model response with an empty diff should produce no current prediction,
// and the request should be auto-reported as rejected with reason `Empty`
// and `was_shown: false`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff; keep the id to match the rejection report.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(request, "");
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // An empty prediction is never surfaced to callers.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
649
// If the buffer changes while a request is in flight such that the user has
// already typed the predicted edit, interpolation leaves nothing to apply:
// no prediction is surfaced and the request is rejected as
// `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is pending, make the buffer already contain the
    // edit that SIMPLE_DIFF would produce.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(request, SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
709
/// Canonical model response reused across tests: a single-hunk diff against
/// the "Hello!\nHow\nBye\n" fixture that rewrites line 2 from "How" to
/// "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
719
// A newer prediction of equal quality replaces the current one, and the
// replaced prediction is reported as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
801
// If a newer prediction is worse than the current one (here: a shorter,
// partial completion), the current prediction is kept and the newer one is
// rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
895
// When a later request completes before an earlier in-flight one, the earlier
// request is considered cancelled: its late response must not replace the
// current prediction, and it is reported rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the earlier (stale) request complete.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
986
// With two requests already pending, starting a third cancels the second
// (the store keeps at most two in flight). The cancelled response is ignored
// when it arrives, the third replaces the first, and the reject report lists
// the second as `Canceled` and the first as `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        // (index 1 here is the per-project slot of the second request)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled request's response must be ignored on arrival.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1122
// Exercises batching/flushing of rejection reports:
// - multiple rejections within the debounce window are sent as one request;
// - hitting the batch-size limit flushes immediately (first 50 of 70), with
//   the remainder debounced into the next batch;
// - a failed flush keeps its items queued and retries them together with
//   newly added rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    // NOTE(review): hard-coded 15s here presumably exceeds
    // REJECT_REQUEST_DEBOUNCE — consider using the constant; confirm.
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1234
1235// Skipped until we start including diagnostics in prompt
1236// #[gpui::test]
1237// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1238// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1239// let fs = FakeFs::new(cx.executor());
1240// fs.insert_tree(
1241// "/root",
1242// json!({
1243// "foo.md": "Hello!\nBye"
1244// }),
1245// )
1246// .await;
1247// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1248
1249// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1250// let diagnostic = lsp::Diagnostic {
1251// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1252// severity: Some(lsp::DiagnosticSeverity::ERROR),
1253// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1254// ..Default::default()
1255// };
1256
1257// project.update(cx, |project, cx| {
1258// project.lsp_store().update(cx, |lsp_store, cx| {
1259// // Create some diagnostics
1260// lsp_store
1261// .update_diagnostics(
1262// LanguageServerId(0),
1263// lsp::PublishDiagnosticsParams {
1264// uri: path_to_buffer_uri.clone(),
1265// diagnostics: vec![diagnostic],
1266// version: None,
1267// },
1268// None,
1269// language::DiagnosticSourceKind::Pushed,
1270// &[],
1271// cx,
1272// )
1273// .unwrap();
1274// });
1275// });
1276
1277// let buffer = project
1278// .update(cx, |project, cx| {
1279// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1280// project.open_buffer(path, cx)
1281// })
1282// .await
1283// .unwrap();
1284
1285// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1286// let position = snapshot.anchor_before(language::Point::new(0, 0));
1287
1288// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1289// ep_store.request_prediction(&project, &buffer, position, cx)
1290// });
1291
1292// let (request, _respond_tx) = req_rx.next().await.unwrap();
1293
1294// assert_eq!(request.diagnostic_groups.len(), 1);
1295// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1296// .unwrap();
1297// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1298// assert_eq!(
1299// value,
1300// json!({
1301// "entries": [{
1302// "range": {
1303// "start": 8,
1304// "end": 10
1305// },
1306// "diagnostic": {
1307// "source": null,
1308// "code": null,
1309// "code_description": null,
1310// "severity": 1,
1311// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1312// "markdown": null,
1313// "group_id": 0,
1314// "is_primary": true,
1315// "is_disk_based": false,
1316// "is_unnecessary": false,
1317// "source_kind": "Pushed",
1318// "data": null,
1319// "underline": true
1320// }
1321// }],
1322// "primary_ix": 0
1323// })
1324// );
1325// }
1326
1327// Generate a model response that would apply the given diff to the active file.
1328fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1329 let prompt = match &request.messages[0] {
1330 open_ai::RequestMessage::User {
1331 content: open_ai::MessageContent::Plain(content),
1332 } => content,
1333 _ => panic!("unexpected request {request:?}"),
1334 };
1335
1336 let open = "<editable_region>\n";
1337 let close = "</editable_region>";
1338 let cursor = "<|user_cursor|>";
1339
1340 let start_ix = open.len() + prompt.find(open).unwrap();
1341 let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1342 let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1343 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1344
1345 open_ai::Response {
1346 id: Uuid::new_v4().to_string(),
1347 object: "response".into(),
1348 created: 0,
1349 model: "model".into(),
1350 choices: vec![open_ai::Choice {
1351 index: 0,
1352 message: open_ai::RequestMessage::Assistant {
1353 content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1354 tool_calls: vec![],
1355 },
1356 finish_reason: None,
1357 }],
1358 usage: Usage {
1359 prompt_tokens: 0,
1360 completion_tokens: 0,
1361 total_tokens: 0,
1362 },
1363 }
1364}
1365
1366fn prompt_from_request(request: &open_ai::Request) -> &str {
1367 assert_eq!(request.messages.len(), 1);
1368 let open_ai::RequestMessage::User {
1369 content: open_ai::MessageContent::Plain(content),
1370 ..
1371 } = &request.messages[0]
1372 else {
1373 panic!(
1374 "Request does not have single user message of type Plain. {:#?}",
1375 request
1376 );
1377 };
1378 content
1379}
1380
/// Channels surfacing fake-server traffic to tests. Each received request is
/// paired with a oneshot sender the test uses to supply (or drop, to simulate
/// failure of) the response.
struct RequestChannels {
    // Requests hitting `/predict_edits/raw`, answered with an OpenAI-style response.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Requests hitting `/predict_edits/reject`, acknowledged with `()`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1385
/// Builds an `EditPredictionStore` backed by a fake HTTP client, returning the
/// store plus channels that forward every predict/reject request to the test
/// together with a oneshot sender for the reply.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answer with a fixed fake LLM token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Forward the parsed request to the test, then block
                        // until the test replies through the oneshot channel.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Same pattern for rejection batches; dropping the
                        // sender on the test side simulates a request failure.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1454
1455const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1456
// Verifies that a prediction's anchored edits are re-mapped ("interpolated")
// as the buffer changes: partially-typed predicted text shrinks the pending
// edit, undo restores it, and conflicting edits invalidate the prediction.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Two pending edits: replace offsets 2..5 with "REM", delete offsets 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        // Prompt inputs are irrelevant to interpolation; use empty defaults.
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range leaves a pure insertion at its start.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing successive prefixes of "REM" consumes the first edit...
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // ...until it is fully typed and drops out of the pending edits.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting typed text re-introduces the corresponding edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the deletion edit manually removes it from the pending set.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the remaining insertion invalidates the
        // prediction entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1570
// Verifies that applying a model's editable-region output to the buffer
// produces the expected cleaned-up text (word renames and string extension).
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Rename `word` -> `word_1` on the range line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Extend a string literal and add the trailing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1628
// Verifies a prediction that appends a new line at the very end of the buffer
// is applied correctly.
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
        ```animals.js
        <|start_of_file|>
        <|editable_region_start|>
        lorem
        ipsum
        <|editable_region_end|>
        ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1648
// Verifies `can_collect_data` in the prediction request follows the user's
// data-collection choice when the worktree carries a permissive LICENSE.
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // 0BSD LICENSE marks the worktree as open source.
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Flipping the choice to Disabled must turn collection off.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1686
// Verifies data collection stays off for a buffer backed by a remote replica,
// even when the user has opted in.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    // Remote buffer: no local file backs it.
    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1714
1715#[gpui::test]
1716async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1717 init_test(cx);
1718
1719 let fs = project::FakeFs::new(cx.executor());
1720 fs.insert_tree(
1721 path!("/project"),
1722 json!({
1723 "LICENSE": BSD_0_TXT,
1724 ".env": "SECRET_KEY=secret"
1725 }),
1726 )
1727 .await;
1728
1729 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1730 let buffer = project
1731 .update(cx, |project, cx| {
1732 project.open_local_buffer("/project/.env", cx)
1733 })
1734 .await
1735 .unwrap();
1736
1737 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1738 ep_store.update(cx, |ep_store, _cx| {
1739 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1740 });
1741
1742 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1743 assert_eq!(
1744 captured_request.lock().clone().unwrap().can_collect_data,
1745 false
1746 );
1747}
1748
// Verifies data collection stays off for an untitled (file-less) buffer.
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    // Local buffer with no associated file.
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1768
1769#[gpui::test]
1770async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1771 init_test(cx);
1772
1773 let fs = project::FakeFs::new(cx.executor());
1774 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1775 .await;
1776
1777 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1778 let buffer = project
1779 .update(cx, |project, cx| {
1780 project.open_local_buffer("/project/main.rs", cx)
1781 })
1782 .await
1783 .unwrap();
1784
1785 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1786 ep_store.update(cx, |ep_store, _cx| {
1787 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1788 });
1789
1790 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1791 assert_eq!(
1792 captured_request.lock().clone().unwrap().can_collect_data,
1793 false
1794 );
1795}
1796
// Verifies the collectability of a buffer is re-evaluated when its file moves
// from an open-source worktree to a closed-source one.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // Worktree with a permissive LICENSE...
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    // ...and one without (closed source).
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Initially in the open-source worktree: collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Re-point the buffer's file at the closed-source worktree, simulating a move.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1860
1861#[gpui::test]
1862async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1863 init_test(cx);
1864
1865 let fs = project::FakeFs::new(cx.executor());
1866 fs.insert_tree(
1867 path!("/worktree1"),
1868 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1869 )
1870 .await;
1871 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1872 .await;
1873
1874 let project = Project::test(
1875 fs.clone(),
1876 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1877 cx,
1878 )
1879 .await;
1880 let buffer = project
1881 .update(cx, |project, cx| {
1882 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1883 })
1884 .await
1885 .unwrap();
1886 let private_buffer = project
1887 .update(cx, |project, cx| {
1888 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1889 })
1890 .await
1891 .unwrap();
1892
1893 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1894 ep_store.update(cx, |ep_store, _cx| {
1895 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1896 });
1897
1898 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1899 assert_eq!(
1900 captured_request.lock().clone().unwrap().can_collect_data,
1901 true
1902 );
1903
1904 // this has a side effect of registering the buffer to watch for edits
1905 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1906 assert_eq!(
1907 captured_request.lock().clone().unwrap().can_collect_data,
1908 false
1909 );
1910
1911 private_buffer.update(cx, |private_buffer, cx| {
1912 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1913 });
1914
1915 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1916 assert_eq!(
1917 captured_request.lock().clone().unwrap().can_collect_data,
1918 false
1919 );
1920
1921 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1922 // included
1923 buffer.update(cx, |buffer, cx| {
1924 buffer.edit(
1925 [(
1926 0..0,
1927 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1928 )],
1929 None,
1930 cx,
1931 );
1932 });
1933
1934 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1935 assert_eq!(
1936 captured_request.lock().clone().unwrap().can_collect_data,
1937 true
1938 );
1939}
1940
/// Installs a test `SettingsStore` global; minimal setup shared by tests that
/// don't need the fake client from `init_test_with_fake_client`.
fn init_test(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
    });
}
1947
/// Runs one edit prediction against a buffer containing `buffer_content`,
/// serving `completion_response` from the fake endpoint, applies the predicted
/// edits, and returns the resulting buffer text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
    // Override the canned completion served by the fake server.
    *response.lock() = completion_response.to_string();
    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
1964
1965async fn run_edit_prediction(
1966 buffer: &Entity<Buffer>,
1967 project: &Entity<Project>,
1968 ep_store: &Entity<EditPredictionStore>,
1969 cx: &mut TestAppContext,
1970) -> EditPrediction {
1971 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1972 ep_store.update(cx, |ep_store, cx| {
1973 ep_store.register_buffer(buffer, &project, cx)
1974 });
1975 cx.background_executor.run_until_parked();
1976 let prediction_task = ep_store.update(cx, |ep_store, cx| {
1977 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1978 });
1979 prediction_task.await.unwrap().unwrap().prediction.unwrap()
1980}
1981
/// Creates an `EditPredictionStore` wired to a fake HTTP endpoint.
///
/// Returns the store, a slot capturing the most recent `PredictEditsBody`
/// sent by the store, and a slot whose contents are served as the completion
/// response (pre-filled with a trivial "hello world" excerpt).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    let default_response = indoc! {"
        ```main.rs
        <|start_of_file|>
        <|editable_region_start|>
        hello world
        <|editable_region_end|>
        ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh: return a fixed fake LLM token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction endpoint: record the request body for the
                    // test, then serve the configurable canned completion.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Seed a license watcher per worktree so open-source detection (and
        // thus data-collection eligibility) works from LICENSE fixtures.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2072
2073fn to_completion_edits(
2074 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2075 buffer: &Entity<Buffer>,
2076 cx: &App,
2077) -> Vec<(Range<Anchor>, Arc<str>)> {
2078 let buffer = buffer.read(cx);
2079 iterator
2080 .into_iter()
2081 .map(|(range, text)| {
2082 (
2083 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2084 text,
2085 )
2086 })
2087 .collect()
2088}
2089
2090fn from_completion_edits(
2091 editor_edits: &[(Range<Anchor>, Arc<str>)],
2092 buffer: &Entity<Buffer>,
2093 cx: &App,
2094) -> Vec<(Range<usize>, Arc<str>)> {
2095 let buffer = buffer.read(cx);
2096 editor_edits
2097 .iter()
2098 .map(|(range, text)| {
2099 (
2100 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2101 text.clone(),
2102 )
2103 })
2104 .collect()
2105}
2106
// Verifies that without credentials and without a custom predict URL, a
// prediction request fails (the fake server answers everything with 401).
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request is rejected: no way to obtain an LLM token.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2164
// Verifies that a custom predict-edits URL bypasses the authentication
// requirement: the predict endpoint is hit even though token requests 401.
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Records whether the predict endpoint was ever reached.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                if uri.contains("predict") {
                    // Custom endpoint: serve a well-formed completion.
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                                ```main.rs
                                                <|start_of_file|>
                                                <|editable_region_start|>
                                                fn main() {
                                                    println!(\"Hello, world!\");
                                                }
                                                <|editable_region_end|>
                                                ```
                                            "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    // Everything else (incl. token refresh) is unauthorized.
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // Outcome of the task itself is irrelevant; only reachability matters here.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2274
// Verifies that `compute_diff_between_snapshots` produces a single unified
// hunk covering two separate insertions, with correct context lines.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Insert in descending row order so the second edit's coordinates aren't
    // shifted by the first.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
        "}
    );
}
2342
// Runs once at test-binary load time (via `ctor`) so logging is initialized
// before any test body executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}