1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{
10 RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
11 },
12};
13use futures::{
14 AsyncReadExt, StreamExt,
15 channel::{mpsc, oneshot},
16};
17use gpui::{
18 Entity, TestAppContext,
19 http_client::{FakeHttpClient, Response},
20};
21use indoc::indoc;
22use language::Point;
23use lsp::LanguageServerId;
24use parking_lot::Mutex;
25use pretty_assertions::{assert_eq, assert_matches};
26use project::{FakeFs, Project};
27use serde_json::json;
28use settings::SettingsStore;
29use std::{path::Path, sync::Arc, time::Duration};
30use util::{path, rel_path::rel_path};
31use uuid::Uuid;
32use zeta_prompt::ZetaPromptInput;
33
34use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
35
// Exercises `prediction_at` across files:
// 1. a prediction in the active buffer (1.txt) is surfaced as `Local`;
// 2. after that prediction is discarded, a diagnostic published for 2.txt
//    leads to a prediction for 2.txt, which is surfaced as a `Jump` when
//    queried from 1.txt and as `Local` once 2.txt is opened and queried.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" (row 1, column 3).
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    // Each predict request arrives together with a sender used to deliver the
    // fake model response.
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The edit targets the buffer being queried, so it is a local prediction.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Clear the current prediction before moving on to the diagnostic case.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // No explicit refresh is issued here: the test expects the diagnostic
    // update itself to produce the next predict request.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from 1.txt, the prediction for 2.txt is offered as a jump.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from 2.txt itself, the same prediction is local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
171
// Requests one prediction directly via `request_prediction` and checks that
// the model's diff becomes a single edit that starts right after "How"
// (row 1, column 3) and carries the text " are you?".
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on the second line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    // The outgoing request plus the sender used to deliver the fake response.
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The line replacement in the diff is expressed as one edit positioned
    // after the unchanged "How" prefix.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
238
// Verifies that an edit made to a registered buffer before a prediction
// request shows up in the request prompt as a unified diff, and that the
// model's response is then turned into a single edit.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // The buffer is registered before it is edited; the prompt assertion
    // below expects the following edit to appear as an event.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The "How" insertion must be rendered as a diff inside the prompt.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
311
// Compares the two edit-history getters on edits separated by a pause:
// `edit_history_for_project` reports one coalesced event, while
// `edit_history_for_project_with_pause_split_last_event` splits that last
// event at the pause into a before-pause diff and an after-pause diff.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How" at offset 7, the start of the empty second line.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line
    // (offset 10 = 7 + len("How")).
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    // Offset 19 = 10 + len(" are you?").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Before the pause: only "How" was inserted.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // After the pause: the diff starts from the "How" state.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
406
// Checks how consecutive edits are grouped into edit-history events by line
// span: edits close to the current event (above or below it) are merged into
// that event, while an edit far away starts a new event. Events are compared
// through `render_events`, which joins each event's diff with a `---` line.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the blank line comes from the first diff's trailing
    // newline followed by the "\n---\n" separator added by `render_events`.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
585
586fn render_events(events: &[StoredEvent]) -> String {
587 events
588 .iter()
589 .map(|e| {
590 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
591 diff.as_str()
592 })
593 .collect::<Vec<_>>()
594 .join("\n---\n")
595}
596
// An empty model response yields no prediction for the buffer, and the
// response is reported as rejected with reason `Empty` (never shown).
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Reply with an empty diff; keep the response id to match it against the
    // rejection report below.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(request, "");
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
651
// While a request is in flight, the user types the text the model will
// suggest. When the response arrives it no longer changes anything, so no
// prediction is shown and it is reported as rejected with `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Before responding, make the buffer already contain "How are you?", the
    // text that SIMPLE_DIFF would insert.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(request, SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
711
// Model response shared by several tests: against "Hello!\nHow\nBye\n" it
// rewrites the second line from "How" to "How are you?". Context lines keep
// their single leading space, as unified diffs require.
const SIMPLE_DIFF: &str = concat!(
    "--- a/root/foo.md\n",
    "+++ b/root/foo.md\n",
    "@@ ... @@\n",
    " Hello!\n",
    "-How\n",
    "+How are you?\n",
    " Bye\n",
);
721
// Two successive refreshes at the same cursor both return SIMPLE_DIFF. The
// second prediction replaces the first as current, and the first is reported
// as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
803
// A second refresh returns a worse suggestion ("How are" instead of
// "How are you?"). The existing prediction stays current, and the new one is
// reported as rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
897
898#[gpui::test]
899async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
900 let (ep_store, mut requests) = init_test_with_fake_client(cx);
901 let fs = FakeFs::new(cx.executor());
902 fs.insert_tree(
903 "/root",
904 json!({
905 "foo.md": "Hello!\nHow\nBye\n"
906 }),
907 )
908 .await;
909 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
910
911 let buffer = project
912 .update(cx, |project, cx| {
913 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
914 project.open_buffer(path, cx)
915 })
916 .await
917 .unwrap();
918 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
919 let position = snapshot.anchor_before(language::Point::new(1, 3));
920
921 // start two refresh tasks
922 ep_store.update(cx, |ep_store, cx| {
923 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
924 });
925
926 let (request1, respond_first) = requests.predict.next().await.unwrap();
927
928 ep_store.update(cx, |ep_store, cx| {
929 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
930 });
931
932 let (request, respond_second) = requests.predict.next().await.unwrap();
933
934 // wait for throttle
935 cx.run_until_parked();
936
937 // second responds first
938 let second_response = model_response(request, SIMPLE_DIFF);
939 let second_id = second_response.id.clone();
940 respond_second.send(second_response).unwrap();
941
942 cx.run_until_parked();
943
944 ep_store.update(cx, |ep_store, cx| {
945 // current prediction is second
946 assert_eq!(
947 ep_store
948 .prediction_at(&buffer, None, &project, cx)
949 .unwrap()
950 .id
951 .0,
952 second_id
953 );
954 });
955
956 let first_response = model_response(request1, SIMPLE_DIFF);
957 let first_id = first_response.id.clone();
958 respond_first.send(first_response).unwrap();
959
960 cx.run_until_parked();
961
962 ep_store.update(cx, |ep_store, cx| {
963 // current prediction is still second, since first was cancelled
964 assert_eq!(
965 ep_store
966 .prediction_at(&buffer, None, &project, cx)
967 .unwrap()
968 .id
969 .0,
970 second_id
971 );
972 });
973
974 // first is reported as rejected
975 let (reject_request, _) = requests.reject.next().await.unwrap();
976
977 cx.run_until_parked();
978
979 assert_eq!(
980 &reject_request.rejections,
981 &[EditPredictionRejection {
982 request_id: first_id,
983 reason: EditPredictionRejectReason::Canceled,
984 was_shown: false
985 }]
986 );
987}
988
// With two refreshes already pending, starting a third cancels the second
// (index 1 in `cancelled_predictions`). The first then becomes current. The
// cancelled second does not displace it when it answers. The third finally
// replaces the first. The resulting rejection report contains the second as
// `Canceled` followed by the first as `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both rejections arrive in one report: the cancelled second request and
    // the first prediction that the third one replaced.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1124
1125#[gpui::test]
1126async fn test_rejections_flushing(cx: &mut TestAppContext) {
1127 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1128
1129 ep_store.update(cx, |ep_store, _cx| {
1130 ep_store.reject_prediction(
1131 EditPredictionId("test-1".into()),
1132 EditPredictionRejectReason::Discarded,
1133 false,
1134 );
1135 ep_store.reject_prediction(
1136 EditPredictionId("test-2".into()),
1137 EditPredictionRejectReason::Canceled,
1138 true,
1139 );
1140 });
1141
1142 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1143 cx.run_until_parked();
1144
1145 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1146 respond_tx.send(()).unwrap();
1147
1148 // batched
1149 assert_eq!(reject_request.rejections.len(), 2);
1150 assert_eq!(
1151 reject_request.rejections[0],
1152 EditPredictionRejection {
1153 request_id: "test-1".to_string(),
1154 reason: EditPredictionRejectReason::Discarded,
1155 was_shown: false
1156 }
1157 );
1158 assert_eq!(
1159 reject_request.rejections[1],
1160 EditPredictionRejection {
1161 request_id: "test-2".to_string(),
1162 reason: EditPredictionRejectReason::Canceled,
1163 was_shown: true
1164 }
1165 );
1166
1167 // Reaching batch size limit sends without debounce
1168 ep_store.update(cx, |ep_store, _cx| {
1169 for i in 0..70 {
1170 ep_store.reject_prediction(
1171 EditPredictionId(format!("batch-{}", i).into()),
1172 EditPredictionRejectReason::Discarded,
1173 false,
1174 );
1175 }
1176 });
1177
1178 // First MAX/2 items are sent immediately
1179 cx.run_until_parked();
1180 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1181 respond_tx.send(()).unwrap();
1182
1183 assert_eq!(reject_request.rejections.len(), 50);
1184 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1185 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1186
1187 // Remaining items are debounced with the next batch
1188 cx.executor().advance_clock(Duration::from_secs(15));
1189 cx.run_until_parked();
1190
1191 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1192 respond_tx.send(()).unwrap();
1193
1194 assert_eq!(reject_request.rejections.len(), 20);
1195 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1196 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1197
1198 // Request failure
1199 ep_store.update(cx, |ep_store, _cx| {
1200 ep_store.reject_prediction(
1201 EditPredictionId("retry-1".into()),
1202 EditPredictionRejectReason::Discarded,
1203 false,
1204 );
1205 });
1206
1207 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1208 cx.run_until_parked();
1209
1210 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1211 assert_eq!(reject_request.rejections.len(), 1);
1212 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1213 // Simulate failure
1214 drop(_respond_tx);
1215
1216 // Add another rejection
1217 ep_store.update(cx, |ep_store, _cx| {
1218 ep_store.reject_prediction(
1219 EditPredictionId("retry-2".into()),
1220 EditPredictionRejectReason::Discarded,
1221 false,
1222 );
1223 });
1224
1225 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1226 cx.run_until_parked();
1227
1228 // Retry should include both the failed item and the new one
1229 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1230 respond_tx.send(()).unwrap();
1231
1232 assert_eq!(reject_request.rejections.len(), 2);
1233 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1234 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1235}
1236
1237// Skipped until we start including diagnostics in prompt
1238// #[gpui::test]
1239// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1240// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1241// let fs = FakeFs::new(cx.executor());
1242// fs.insert_tree(
1243// "/root",
1244// json!({
1245// "foo.md": "Hello!\nBye"
1246// }),
1247// )
1248// .await;
1249// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1250
1251// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1252// let diagnostic = lsp::Diagnostic {
1253// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1254// severity: Some(lsp::DiagnosticSeverity::ERROR),
1255// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1256// ..Default::default()
1257// };
1258
1259// project.update(cx, |project, cx| {
1260// project.lsp_store().update(cx, |lsp_store, cx| {
1261// // Create some diagnostics
1262// lsp_store
1263// .update_diagnostics(
1264// LanguageServerId(0),
1265// lsp::PublishDiagnosticsParams {
1266// uri: path_to_buffer_uri.clone(),
1267// diagnostics: vec![diagnostic],
1268// version: None,
1269// },
1270// None,
1271// language::DiagnosticSourceKind::Pushed,
1272// &[],
1273// cx,
1274// )
1275// .unwrap();
1276// });
1277// });
1278
1279// let buffer = project
1280// .update(cx, |project, cx| {
1281// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1282// project.open_buffer(path, cx)
1283// })
1284// .await
1285// .unwrap();
1286
1287// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1288// let position = snapshot.anchor_before(language::Point::new(0, 0));
1289
1290// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1291// ep_store.request_prediction(&project, &buffer, position, cx)
1292// });
1293
1294// let (request, _respond_tx) = req_rx.next().await.unwrap();
1295
1296// assert_eq!(request.diagnostic_groups.len(), 1);
1297// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1298// .unwrap();
1299// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1300// assert_eq!(
1301// value,
1302// json!({
1303// "entries": [{
1304// "range": {
1305// "start": 8,
1306// "end": 10
1307// },
1308// "diagnostic": {
1309// "source": null,
1310// "code": null,
1311// "code_description": null,
1312// "severity": 1,
1313// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1314// "markdown": null,
1315// "group_id": 0,
1316// "is_primary": true,
1317// "is_disk_based": false,
1318// "is_unnecessary": false,
1319// "source_kind": "Pushed",
1320// "data": null,
1321// "underline": true
1322// }
1323// }],
1324// "primary_ix": 0
1325// })
1326// );
1327// }
1328
1329// Generate a model response that would apply the given diff to the active file.
1330fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
1331 let prompt = &request.prompt;
1332
1333 let current_marker = "<|fim_middle|>current\n";
1334 let updated_marker = "<|fim_middle|>updated\n";
1335 let suffix_marker = "<|fim_suffix|>\n";
1336 let cursor = "<|user_cursor|>";
1337
1338 let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
1339 let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
1340 let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1341 // In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content.
1342 // Strip that out to get just the editable region.
1343 let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) {
1344 &excerpt[..suffix_pos]
1345 } else {
1346 &excerpt
1347 };
1348 let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap();
1349
1350 RawCompletionResponse {
1351 id: Uuid::new_v4().to_string(),
1352 object: "text_completion".into(),
1353 created: 0,
1354 model: "model".into(),
1355 choices: vec![RawCompletionChoice {
1356 text: new_excerpt,
1357 finish_reason: None,
1358 }],
1359 usage: RawCompletionUsage {
1360 prompt_tokens: 0,
1361 completion_tokens: 0,
1362 total_tokens: 0,
1363 },
1364 }
1365}
1366
1367fn prompt_from_request(request: &RawCompletionRequest) -> &str {
1368 &request.prompt
1369}
1370
/// Receiving ends of the fake HTTP endpoints installed by `init_test_with_fake_client`.
///
/// Each item pairs the deserialized request body with a oneshot sender. The test answers
/// by sending a response through it, or drops the sender to make the HTTP call fail.
struct RequestChannels {
    /// Requests posted to `/predict_edits/raw`.
    predict:
        mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
    /// Requests posted to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1376
1377fn init_test_with_fake_client(
1378 cx: &mut TestAppContext,
1379) -> (Entity<EditPredictionStore>, RequestChannels) {
1380 cx.update(move |cx| {
1381 let settings_store = SettingsStore::test(cx);
1382 cx.set_global(settings_store);
1383 zlog::init_test();
1384
1385 let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1386 let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1387
1388 let http_client = FakeHttpClient::create({
1389 move |req| {
1390 let uri = req.uri().path().to_string();
1391 let mut body = req.into_body();
1392 let predict_req_tx = predict_req_tx.clone();
1393 let reject_req_tx = reject_req_tx.clone();
1394 async move {
1395 let resp = match uri.as_str() {
1396 "/client/llm_tokens" => serde_json::to_string(&json!({
1397 "token": "test"
1398 }))
1399 .unwrap(),
1400 "/predict_edits/raw" => {
1401 let mut buf = Vec::new();
1402 body.read_to_end(&mut buf).await.ok();
1403 let req = serde_json::from_slice(&buf).unwrap();
1404
1405 let (res_tx, res_rx) = oneshot::channel();
1406 predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1407 serde_json::to_string(&res_rx.await?).unwrap()
1408 }
1409 "/predict_edits/reject" => {
1410 let mut buf = Vec::new();
1411 body.read_to_end(&mut buf).await.ok();
1412 let req = serde_json::from_slice(&buf).unwrap();
1413
1414 let (res_tx, res_rx) = oneshot::channel();
1415 reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1416 serde_json::to_string(&res_rx.await?).unwrap()
1417 }
1418 _ => {
1419 panic!("Unexpected path: {}", uri)
1420 }
1421 };
1422
1423 Ok(Response::builder().body(resp.into()).unwrap())
1424 }
1425 }
1426 });
1427
1428 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1429 client.cloud_client().set_credentials(1, "test".into());
1430
1431 language_model::init(client.clone(), cx);
1432
1433 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1434 let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1435
1436 (
1437 ep_store,
1438 RequestChannels {
1439 predict: predict_req_rx,
1440 reject: reject_req_rx,
1441 },
1442 )
1443 })
1444}
1445
/// 0BSD license text. The data-collection tests below write it as a worktree's `LICENSE`.
/// Worktrees containing it are expected to allow data collection; worktrees without it
/// are not.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1447
1448#[gpui::test]
1449async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1450 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1451 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1452 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1453 });
1454
1455 let edit_preview = cx
1456 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1457 .await;
1458
1459 let prediction = EditPrediction {
1460 edits,
1461 edit_preview,
1462 buffer: buffer.clone(),
1463 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1464 id: EditPredictionId("the-id".into()),
1465 inputs: ZetaPromptInput {
1466 events: Default::default(),
1467 related_files: Default::default(),
1468 cursor_path: Path::new("").into(),
1469 cursor_excerpt: "".into(),
1470 editable_range_in_excerpt: 0..0,
1471 cursor_offset_in_excerpt: 0,
1472 },
1473 buffer_snapshotted_at: Instant::now(),
1474 response_received_at: Instant::now(),
1475 };
1476
1477 cx.update(|cx| {
1478 assert_eq!(
1479 from_completion_edits(
1480 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1481 &buffer,
1482 cx
1483 ),
1484 vec![(2..5, "REM".into()), (9..11, "".into())]
1485 );
1486
1487 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1488 assert_eq!(
1489 from_completion_edits(
1490 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1491 &buffer,
1492 cx
1493 ),
1494 vec![(2..2, "REM".into()), (6..8, "".into())]
1495 );
1496
1497 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1498 assert_eq!(
1499 from_completion_edits(
1500 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1501 &buffer,
1502 cx
1503 ),
1504 vec![(2..5, "REM".into()), (9..11, "".into())]
1505 );
1506
1507 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1508 assert_eq!(
1509 from_completion_edits(
1510 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1511 &buffer,
1512 cx
1513 ),
1514 vec![(3..3, "EM".into()), (7..9, "".into())]
1515 );
1516
1517 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1518 assert_eq!(
1519 from_completion_edits(
1520 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1521 &buffer,
1522 cx
1523 ),
1524 vec![(4..4, "M".into()), (8..10, "".into())]
1525 );
1526
1527 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1528 assert_eq!(
1529 from_completion_edits(
1530 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1531 &buffer,
1532 cx
1533 ),
1534 vec![(9..11, "".into())]
1535 );
1536
1537 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1538 assert_eq!(
1539 from_completion_edits(
1540 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1541 &buffer,
1542 cx
1543 ),
1544 vec![(4..4, "M".into()), (8..10, "".into())]
1545 );
1546
1547 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1548 assert_eq!(
1549 from_completion_edits(
1550 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1551 &buffer,
1552 cx
1553 ),
1554 vec![(4..4, "M".into())]
1555 );
1556
1557 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1558 assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1559 })
1560}
1561
1562#[gpui::test]
1563async fn test_clean_up_diff(cx: &mut TestAppContext) {
1564 init_test(cx);
1565
1566 assert_eq!(
1567 apply_edit_prediction(
1568 indoc! {"
1569 fn main() {
1570 let word_1 = \"lorem\";
1571 let range = word.len()..word.len();
1572 }
1573 "},
1574 indoc! {"
1575 <|editable_region_start|>
1576 fn main() {
1577 let word_1 = \"lorem\";
1578 let range = word_1.len()..word_1.len();
1579 }
1580
1581 <|editable_region_end|>
1582 "},
1583 cx,
1584 )
1585 .await,
1586 indoc! {"
1587 fn main() {
1588 let word_1 = \"lorem\";
1589 let range = word_1.len()..word_1.len();
1590 }
1591 "},
1592 );
1593
1594 assert_eq!(
1595 apply_edit_prediction(
1596 indoc! {"
1597 fn main() {
1598 let story = \"the quick\"
1599 }
1600 "},
1601 indoc! {"
1602 <|editable_region_start|>
1603 fn main() {
1604 let story = \"the quick brown fox jumps over the lazy dog\";
1605 }
1606
1607 <|editable_region_end|>
1608 "},
1609 cx,
1610 )
1611 .await,
1612 indoc! {"
1613 fn main() {
1614 let story = \"the quick brown fox jumps over the lazy dog\";
1615 }
1616 "},
1617 );
1618}
1619
1620#[gpui::test]
1621async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1622 init_test(cx);
1623
1624 let buffer_content = "lorem\n";
1625 let completion_response = indoc! {"
1626 ```animals.js
1627 <|start_of_file|>
1628 <|editable_region_start|>
1629 lorem
1630 ipsum
1631 <|editable_region_end|>
1632 ```"};
1633
1634 assert_eq!(
1635 apply_edit_prediction(buffer_content, completion_response, cx).await,
1636 "lorem\nipsum"
1637 );
1638}
1639
1640#[gpui::test]
1641async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1642 // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1643 // When the buffer ends without a trailing newline, but the model returns output
1644 // with a trailing newline, zeta2 should normalize both sides before diffing
1645 // so no spurious newline is inserted.
1646 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1647 let fs = FakeFs::new(cx.executor());
1648
1649 // Single line buffer with no trailing newline
1650 fs.insert_tree(
1651 "/root",
1652 json!({
1653 "foo.txt": "hello"
1654 }),
1655 )
1656 .await;
1657 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1658
1659 let buffer = project
1660 .update(cx, |project, cx| {
1661 let path = project
1662 .find_project_path(path!("root/foo.txt"), cx)
1663 .unwrap();
1664 project.open_buffer(path, cx)
1665 })
1666 .await
1667 .unwrap();
1668
1669 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1670 let position = snapshot.anchor_before(language::Point::new(0, 5));
1671
1672 ep_store.update(cx, |ep_store, cx| {
1673 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1674 });
1675
1676 let (_request, respond_tx) = requests.predict.next().await.unwrap();
1677
1678 // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1679 // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1680 let response = RawCompletionResponse {
1681 id: Uuid::new_v4().to_string(),
1682 object: "text_completion".into(),
1683 created: 0,
1684 model: "model".into(),
1685 choices: vec![RawCompletionChoice {
1686 text: "hello world\n".to_string(),
1687 finish_reason: None,
1688 }],
1689 usage: RawCompletionUsage {
1690 prompt_tokens: 0,
1691 completion_tokens: 0,
1692 total_tokens: 0,
1693 },
1694 };
1695 respond_tx.send(response).unwrap();
1696
1697 cx.run_until_parked();
1698
1699 // The prediction should insert " world" without adding a newline
1700 ep_store.update(cx, |ep_store, cx| {
1701 let prediction = ep_store
1702 .prediction_at(&buffer, None, &project, cx)
1703 .expect("should have prediction");
1704 let edits: Vec<_> = prediction
1705 .edits
1706 .iter()
1707 .map(|(range, text)| {
1708 let snapshot = buffer.read(cx).snapshot();
1709 (range.to_offset(&snapshot), text.clone())
1710 })
1711 .collect();
1712 assert_eq!(edits, vec![(5..5, " world".into())]);
1713 });
1714}
1715
1716#[gpui::test]
1717async fn test_can_collect_data(cx: &mut TestAppContext) {
1718 init_test(cx);
1719
1720 let fs = project::FakeFs::new(cx.executor());
1721 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1722 .await;
1723
1724 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1725 let buffer = project
1726 .update(cx, |project, cx| {
1727 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1728 })
1729 .await
1730 .unwrap();
1731
1732 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1733 ep_store.update(cx, |ep_store, _cx| {
1734 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1735 });
1736
1737 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1738 assert_eq!(
1739 captured_request.lock().clone().unwrap().can_collect_data,
1740 true
1741 );
1742
1743 ep_store.update(cx, |ep_store, _cx| {
1744 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1745 });
1746
1747 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1748 assert_eq!(
1749 captured_request.lock().clone().unwrap().can_collect_data,
1750 false
1751 );
1752}
1753
1754#[gpui::test]
1755async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1756 init_test(cx);
1757
1758 let fs = project::FakeFs::new(cx.executor());
1759 let project = Project::test(fs.clone(), [], cx).await;
1760
1761 let buffer = cx.new(|_cx| {
1762 Buffer::remote(
1763 language::BufferId::new(1).unwrap(),
1764 ReplicaId::new(1),
1765 language::Capability::ReadWrite,
1766 "fn main() {\n println!(\"Hello\");\n}",
1767 )
1768 });
1769
1770 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1771 ep_store.update(cx, |ep_store, _cx| {
1772 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1773 });
1774
1775 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1776 assert_eq!(
1777 captured_request.lock().clone().unwrap().can_collect_data,
1778 false
1779 );
1780}
1781
1782#[gpui::test]
1783async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1784 init_test(cx);
1785
1786 let fs = project::FakeFs::new(cx.executor());
1787 fs.insert_tree(
1788 path!("/project"),
1789 json!({
1790 "LICENSE": BSD_0_TXT,
1791 ".env": "SECRET_KEY=secret"
1792 }),
1793 )
1794 .await;
1795
1796 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1797 let buffer = project
1798 .update(cx, |project, cx| {
1799 project.open_local_buffer("/project/.env", cx)
1800 })
1801 .await
1802 .unwrap();
1803
1804 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1805 ep_store.update(cx, |ep_store, _cx| {
1806 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1807 });
1808
1809 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1810 assert_eq!(
1811 captured_request.lock().clone().unwrap().can_collect_data,
1812 false
1813 );
1814}
1815
1816#[gpui::test]
1817async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1818 init_test(cx);
1819
1820 let fs = project::FakeFs::new(cx.executor());
1821 let project = Project::test(fs.clone(), [], cx).await;
1822 let buffer = cx.new(|cx| Buffer::local("", cx));
1823
1824 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1825 ep_store.update(cx, |ep_store, _cx| {
1826 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1827 });
1828
1829 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1830 assert_eq!(
1831 captured_request.lock().clone().unwrap().can_collect_data,
1832 false
1833 );
1834}
1835
1836#[gpui::test]
1837async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1838 init_test(cx);
1839
1840 let fs = project::FakeFs::new(cx.executor());
1841 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1842 .await;
1843
1844 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1845 let buffer = project
1846 .update(cx, |project, cx| {
1847 project.open_local_buffer("/project/main.rs", cx)
1848 })
1849 .await
1850 .unwrap();
1851
1852 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1853 ep_store.update(cx, |ep_store, _cx| {
1854 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1855 });
1856
1857 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1858 assert_eq!(
1859 captured_request.lock().clone().unwrap().can_collect_data,
1860 false
1861 );
1862}
1863
1864#[gpui::test]
1865async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1866 init_test(cx);
1867
1868 let fs = project::FakeFs::new(cx.executor());
1869 fs.insert_tree(
1870 path!("/open_source_worktree"),
1871 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1872 )
1873 .await;
1874 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1875 .await;
1876
1877 let project = Project::test(
1878 fs.clone(),
1879 [
1880 path!("/open_source_worktree").as_ref(),
1881 path!("/closed_source_worktree").as_ref(),
1882 ],
1883 cx,
1884 )
1885 .await;
1886 let buffer = project
1887 .update(cx, |project, cx| {
1888 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1889 })
1890 .await
1891 .unwrap();
1892
1893 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1894 ep_store.update(cx, |ep_store, _cx| {
1895 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1896 });
1897
1898 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1899 assert_eq!(
1900 captured_request.lock().clone().unwrap().can_collect_data,
1901 true
1902 );
1903
1904 let closed_source_file = project
1905 .update(cx, |project, cx| {
1906 let worktree2 = project
1907 .worktree_for_root_name("closed_source_worktree", cx)
1908 .unwrap();
1909 worktree2.update(cx, |worktree2, cx| {
1910 worktree2.load_file(rel_path("main.rs"), cx)
1911 })
1912 })
1913 .await
1914 .unwrap()
1915 .file;
1916
1917 buffer.update(cx, |buffer, cx| {
1918 buffer.file_updated(closed_source_file, cx);
1919 });
1920
1921 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1922 assert_eq!(
1923 captured_request.lock().clone().unwrap().can_collect_data,
1924 false
1925 );
1926}
1927
1928#[gpui::test]
1929async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1930 init_test(cx);
1931
1932 let fs = project::FakeFs::new(cx.executor());
1933 fs.insert_tree(
1934 path!("/worktree1"),
1935 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1936 )
1937 .await;
1938 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1939 .await;
1940
1941 let project = Project::test(
1942 fs.clone(),
1943 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1944 cx,
1945 )
1946 .await;
1947 let buffer = project
1948 .update(cx, |project, cx| {
1949 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1950 })
1951 .await
1952 .unwrap();
1953 let private_buffer = project
1954 .update(cx, |project, cx| {
1955 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1956 })
1957 .await
1958 .unwrap();
1959
1960 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1961 ep_store.update(cx, |ep_store, _cx| {
1962 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1963 });
1964
1965 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1966 assert_eq!(
1967 captured_request.lock().clone().unwrap().can_collect_data,
1968 true
1969 );
1970
1971 // this has a side effect of registering the buffer to watch for edits
1972 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1973 assert_eq!(
1974 captured_request.lock().clone().unwrap().can_collect_data,
1975 false
1976 );
1977
1978 private_buffer.update(cx, |private_buffer, cx| {
1979 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1980 });
1981
1982 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1983 assert_eq!(
1984 captured_request.lock().clone().unwrap().can_collect_data,
1985 false
1986 );
1987
1988 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1989 // included
1990 buffer.update(cx, |buffer, cx| {
1991 buffer.edit(
1992 [(
1993 0..0,
1994 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1995 )],
1996 None,
1997 cx,
1998 );
1999 });
2000
2001 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2002 assert_eq!(
2003 captured_request.lock().clone().unwrap().can_collect_data,
2004 true
2005 );
2006}
2007
2008fn init_test(cx: &mut TestAppContext) {
2009 cx.update(|cx| {
2010 let settings_store = SettingsStore::test(cx);
2011 cx.set_global(settings_store);
2012 });
2013}
2014
2015async fn apply_edit_prediction(
2016 buffer_content: &str,
2017 completion_response: &str,
2018 cx: &mut TestAppContext,
2019) -> String {
2020 let fs = project::FakeFs::new(cx.executor());
2021 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2022 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2023 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2024 *response.lock() = completion_response.to_string();
2025 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2026 buffer.update(cx, |buffer, cx| {
2027 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2028 });
2029 buffer.read_with(cx, |buffer, _| buffer.text())
2030}
2031
2032async fn run_edit_prediction(
2033 buffer: &Entity<Buffer>,
2034 project: &Entity<Project>,
2035 ep_store: &Entity<EditPredictionStore>,
2036 cx: &mut TestAppContext,
2037) -> EditPrediction {
2038 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2039 ep_store.update(cx, |ep_store, cx| {
2040 ep_store.register_buffer(buffer, &project, cx)
2041 });
2042 cx.background_executor.run_until_parked();
2043 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2044 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2045 });
2046 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2047}
2048
2049async fn make_test_ep_store(
2050 project: &Entity<Project>,
2051 cx: &mut TestAppContext,
2052) -> (
2053 Entity<EditPredictionStore>,
2054 Arc<Mutex<Option<PredictEditsBody>>>,
2055 Arc<Mutex<String>>,
2056) {
2057 let default_response = indoc! {"
2058 ```main.rs
2059 <|start_of_file|>
2060 <|editable_region_start|>
2061 hello world
2062 <|editable_region_end|>
2063 ```"
2064 };
2065 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2066 let completion_response: Arc<Mutex<String>> =
2067 Arc::new(Mutex::new(default_response.to_string()));
2068 let http_client = FakeHttpClient::create({
2069 let captured_request = captured_request.clone();
2070 let completion_response = completion_response.clone();
2071 let mut next_request_id = 0;
2072 move |req| {
2073 let captured_request = captured_request.clone();
2074 let completion_response = completion_response.clone();
2075 async move {
2076 match (req.method(), req.uri().path()) {
2077 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2078 .status(200)
2079 .body(
2080 serde_json::to_string(&CreateLlmTokenResponse {
2081 token: LlmToken("the-llm-token".to_string()),
2082 })
2083 .unwrap()
2084 .into(),
2085 )
2086 .unwrap()),
2087 (&Method::POST, "/predict_edits/v2") => {
2088 let mut request_body = String::new();
2089 req.into_body().read_to_string(&mut request_body).await?;
2090 *captured_request.lock() =
2091 Some(serde_json::from_str(&request_body).unwrap());
2092 next_request_id += 1;
2093 Ok(http_client::Response::builder()
2094 .status(200)
2095 .body(
2096 serde_json::to_string(&PredictEditsResponse {
2097 request_id: format!("request-{next_request_id}"),
2098 output_excerpt: completion_response.lock().clone(),
2099 })
2100 .unwrap()
2101 .into(),
2102 )
2103 .unwrap())
2104 }
2105 _ => Ok(http_client::Response::builder()
2106 .status(404)
2107 .body("Not Found".into())
2108 .unwrap()),
2109 }
2110 }
2111 }
2112 });
2113
2114 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2115 cx.update(|cx| {
2116 RefreshLlmTokenListener::register(client.clone(), cx);
2117 });
2118 let _server = FakeServer::for_client(42, &client, cx).await;
2119
2120 let ep_store = cx.new(|cx| {
2121 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2122 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2123
2124 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2125 for worktree in worktrees {
2126 let worktree_id = worktree.read(cx).id();
2127 ep_store
2128 .get_or_init_project(project, cx)
2129 .license_detection_watchers
2130 .entry(worktree_id)
2131 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2132 }
2133
2134 ep_store
2135 });
2136
2137 (ep_store, captured_request, completion_response)
2138}
2139
2140fn to_completion_edits(
2141 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2142 buffer: &Entity<Buffer>,
2143 cx: &App,
2144) -> Vec<(Range<Anchor>, Arc<str>)> {
2145 let buffer = buffer.read(cx);
2146 iterator
2147 .into_iter()
2148 .map(|(range, text)| {
2149 (
2150 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2151 text,
2152 )
2153 })
2154 .collect()
2155}
2156
2157fn from_completion_edits(
2158 editor_edits: &[(Range<Anchor>, Arc<str>)],
2159 buffer: &Entity<Buffer>,
2160 cx: &App,
2161) -> Vec<(Range<usize>, Arc<str>)> {
2162 let buffer = buffer.read(cx);
2163 editor_edits
2164 .iter()
2165 .map(|(range, text)| {
2166 (
2167 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2168 text.clone(),
2169 )
2170 })
2171 .collect()
2172}
2173
2174#[gpui::test]
2175async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2176 init_test(cx);
2177
2178 let fs = FakeFs::new(cx.executor());
2179 fs.insert_tree(
2180 "/project",
2181 serde_json::json!({
2182 "main.rs": "fn main() {\n \n}\n"
2183 }),
2184 )
2185 .await;
2186
2187 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2188
2189 let http_client = FakeHttpClient::create(|_req| async move {
2190 Ok(gpui::http_client::Response::builder()
2191 .status(401)
2192 .body("Unauthorized".into())
2193 .unwrap())
2194 });
2195
2196 let client =
2197 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2198 cx.update(|cx| {
2199 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2200 });
2201
2202 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2203
2204 let buffer = project
2205 .update(cx, |project, cx| {
2206 let path = project
2207 .find_project_path(path!("/project/main.rs"), cx)
2208 .unwrap();
2209 project.open_buffer(path, cx)
2210 })
2211 .await
2212 .unwrap();
2213
2214 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2215 ep_store.update(cx, |ep_store, cx| {
2216 ep_store.register_buffer(&buffer, &project, cx)
2217 });
2218 cx.background_executor.run_until_parked();
2219
2220 let completion_task = ep_store.update(cx, |ep_store, cx| {
2221 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2222 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2223 });
2224
2225 let result = completion_task.await;
2226 assert!(
2227 result.is_err(),
2228 "Without authentication and without custom URL, prediction should fail"
2229 );
2230}
2231
2232#[gpui::test]
2233async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2234 init_test(cx);
2235
2236 let fs = FakeFs::new(cx.executor());
2237 fs.insert_tree(
2238 "/project",
2239 serde_json::json!({
2240 "main.rs": "fn main() {\n \n}\n"
2241 }),
2242 )
2243 .await;
2244
2245 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2246
2247 let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2248 let predict_called_clone = predict_called.clone();
2249
2250 let http_client = FakeHttpClient::create({
2251 move |req| {
2252 let uri = req.uri().path().to_string();
2253 let predict_called = predict_called_clone.clone();
2254 async move {
2255 if uri.contains("predict") {
2256 predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2257 Ok(gpui::http_client::Response::builder()
2258 .body(
2259 serde_json::to_string(&open_ai::Response {
2260 id: "test-123".to_string(),
2261 object: "chat.completion".to_string(),
2262 created: 0,
2263 model: "test".to_string(),
2264 usage: open_ai::Usage {
2265 prompt_tokens: 0,
2266 completion_tokens: 0,
2267 total_tokens: 0,
2268 },
2269 choices: vec![open_ai::Choice {
2270 index: 0,
2271 message: open_ai::RequestMessage::Assistant {
2272 content: Some(open_ai::MessageContent::Plain(
2273 indoc! {"
2274 ```main.rs
2275 <|start_of_file|>
2276 <|editable_region_start|>
2277 fn main() {
2278 println!(\"Hello, world!\");
2279 }
2280 <|editable_region_end|>
2281 ```
2282 "}
2283 .to_string(),
2284 )),
2285 tool_calls: vec![],
2286 },
2287 finish_reason: Some("stop".to_string()),
2288 }],
2289 })
2290 .unwrap()
2291 .into(),
2292 )
2293 .unwrap())
2294 } else {
2295 Ok(gpui::http_client::Response::builder()
2296 .status(401)
2297 .body("Unauthorized".into())
2298 .unwrap())
2299 }
2300 }
2301 }
2302 });
2303
2304 let client =
2305 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2306 cx.update(|cx| {
2307 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2308 });
2309
2310 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2311
2312 let buffer = project
2313 .update(cx, |project, cx| {
2314 let path = project
2315 .find_project_path(path!("/project/main.rs"), cx)
2316 .unwrap();
2317 project.open_buffer(path, cx)
2318 })
2319 .await
2320 .unwrap();
2321
2322 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2323 ep_store.update(cx, |ep_store, cx| {
2324 ep_store.register_buffer(&buffer, &project, cx)
2325 });
2326 cx.background_executor.run_until_parked();
2327
2328 let completion_task = ep_store.update(cx, |ep_store, cx| {
2329 ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2330 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2331 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2332 });
2333
2334 let _ = completion_task.await;
2335
2336 assert!(
2337 predict_called.load(std::sync::atomic::Ordering::SeqCst),
2338 "With custom URL, predict endpoint should be called even without authentication"
2339 );
2340}
2341
2342#[gpui::test]
2343fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2344 let buffer = cx.new(|cx| {
2345 Buffer::local(
2346 indoc! {"
2347 zero
2348 one
2349 two
2350 three
2351 four
2352 five
2353 six
2354 seven
2355 eight
2356 nine
2357 ten
2358 eleven
2359 twelve
2360 thirteen
2361 fourteen
2362 fifteen
2363 sixteen
2364 seventeen
2365 eighteen
2366 nineteen
2367 twenty
2368 twenty-one
2369 twenty-two
2370 twenty-three
2371 twenty-four
2372 "},
2373 cx,
2374 )
2375 });
2376
2377 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2378
2379 buffer.update(cx, |buffer, cx| {
2380 let point = Point::new(12, 0);
2381 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2382 let point = Point::new(8, 0);
2383 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2384 });
2385
2386 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2387
2388 let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2389
2390 assert_eq!(
2391 diff,
2392 indoc! {"
2393 @@ -6,10 +6,12 @@
2394 five
2395 six
2396 seven
2397 +FIRST INSERTION
2398 eight
2399 nine
2400 ten
2401 eleven
2402 +SECOND INSERTION
2403 twelve
2404 thirteen
2405 fourteen
2406 "}
2407 );
2408}
2409
/// Sets up `zlog` logging for the test binary by calling `zlog::init_test`.
///
/// The `#[ctor::ctor]` attribute runs this once, when the binary is loaded and
/// before any test executes, so individual tests need not initialize logging.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}