1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
10};
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::App;
16use gpui::{
17 Entity, TestAppContext,
18 http_client::{FakeHttpClient, Response},
19};
20use indoc::indoc;
21use language::{Buffer, Point};
22use lsp::LanguageServerId;
23use parking_lot::Mutex;
24use pretty_assertions::{assert_eq, assert_matches};
25use project::{FakeFs, Project};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, sync::Arc, time::Duration};
29use util::{path, rel_path::rel_path};
30use uuid::Uuid;
31use zeta_prompt::ZetaPromptInput;
32
33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
34
/// End-to-end check of `prediction_at` routing: a prediction for the active
/// buffer surfaces as `BufferEditPrediction::Local`, while a prediction
/// targeting a different file (triggered here by a pushed diagnostic)
/// surfaces as `BufferEditPrediction::Jump` until that file's own buffer is
/// queried, at which point it becomes `Local` again.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and mark it as the active path so predictions target it.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor on line 1, column 3 (inside "How").
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Fake model responds with a diff that edits the file the cursor is in.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // A same-buffer prediction is reported as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard it so the next prediction can take over.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(
            EditPredictionRejectReason::Discarded,
            &project,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Pushing a diagnostic for 2.txt triggers a new prediction request.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                Hola!
                -Como
                +Como estas?
                Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from buffer1, a prediction for 2.txt is a `Jump` to that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
175
/// A single `request_prediction` round-trip: the fake model returns a unified
/// diff and the resulting prediction contains exactly one edit, anchored at
/// the cursor line, inserting " are you?".
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff's single changed line maps to one edit at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
242
/// Verifies that recent buffer edits are included in the prediction prompt:
/// after editing a registered buffer, the outgoing request's prompt contains
/// a unified diff describing that edit.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registering the buffer makes the store track its edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" into the empty second line (offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt should embed the edit-history diff for the change above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
315
/// Verifies the pause-splitting edit-history getter: edits separated by a
/// pause longer than `LAST_CHANGE_GROUPING_TIME` are coalesced into one event
/// by the plain getter, but split into two distinct events (one per burst) by
/// `edit_history_for_project_with_pause_split_last_event`.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    // The single coalesced event spans both bursts.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How are you?!
            Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: only the pre-pause burst ("How").
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}
    );

    // Second event: the post-pause burst applied on top of the first.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
            Hello!
            -How
            +How are you?!
            Bye
        "}
    );
}
410
/// Exercises line-span based coalescing of edit-history events: edits within
/// 8 lines of an existing event's range (above or below it) merge into that
/// event, while an edit farther away starts a new event.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Still a single event: the insertion merged into the existing diff.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            Line 15
            Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Still a single event: the below-range insertion merged too.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: render_events separates them with "---".
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
            Line 3
            Line 4
            Line 5
            +Above
            Line 6
            Line 7
            Line 8
            Line 9
            Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
            Line 14
            +Below
            Line 15
            Line 16
            Line 17

            ---
            @@ -23,6 +23,7 @@
            Line 22
            Line 23
            Line 24
            +Far below
            Line 25
            Line 26
            Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
589
590fn render_events(events: &[StoredEvent]) -> String {
591 events
592 .iter()
593 .map(|e| {
594 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
595 diff.as_str()
596 })
597 .collect::<Vec<_>>()
598 .join("\n---\n")
599}
600
/// When the model returns an empty response, no prediction is stored and the
/// request is automatically reported back as rejected with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff; remember the request id for the rejection check.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // An empty prediction is never surfaced.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
655
/// If the user applies the predicted change themselves while the request is
/// in flight, interpolating the prediction against the new buffer contents
/// leaves nothing to do: no prediction is surfaced and the request is
/// reported as rejected with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Before the response arrives, manually make the exact edit the model is
    // about to suggest.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The interpolated prediction is empty, so nothing is surfaced.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
715
/// Canonical model response used by several tests below: a unified diff that
/// rewrites line 2 of `root/foo.md` from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
725
/// A newer prediction of equal quality replaces the current one, and the
/// replaced prediction is reported back with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is now current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
807
/// If a newer prediction is worse than the current one (here: a strict prefix
/// of the current suggestion), the current prediction is kept and the new one
/// is reported as rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is now current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
901
/// When a newer request's response arrives first, the earlier pending request
/// is cancelled: its (late) response does not replace the current prediction,
/// and it is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the older request respond late.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
992
/// With two requests already pending, starting a third cancels the second
/// (the middle one): the first's response still lands, the second's late
/// response is ignored, the third's response replaces the first, and the
/// rejection report contains `Canceled` for the second plus `Replaced` for
/// the first.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request responds late; its result must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1128
/// Covers the rejection-reporting pipeline:
/// 1. rejections are debounced and sent in one batch;
/// 2. hitting the batch-size limit flushes half the queue immediately,
///    leaving the rest for the next debounced batch;
/// 3. a failed send is retried, merged with newer rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    // Queue two rejections back-to-back before the debounce elapses.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                edit_prediction_types::EditPredictionDismissReason::Ignored,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    // (15s comfortably exceeds the debounce interval).
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure: dropping the responder without sending counts as a
    // failed request, so "retry-1" must be re-queued.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1250
1251// Skipped until we start including diagnostics in prompt
1252// #[gpui::test]
1253// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1254// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1255// let fs = FakeFs::new(cx.executor());
1256// fs.insert_tree(
1257// "/root",
1258// json!({
1259// "foo.md": "Hello!\nBye"
1260// }),
1261// )
1262// .await;
1263// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1264
1265// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1266// let diagnostic = lsp::Diagnostic {
1267// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1268// severity: Some(lsp::DiagnosticSeverity::ERROR),
1269// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1270// ..Default::default()
1271// };
1272
1273// project.update(cx, |project, cx| {
1274// project.lsp_store().update(cx, |lsp_store, cx| {
1275// // Create some diagnostics
1276// lsp_store
1277// .update_diagnostics(
1278// LanguageServerId(0),
1279// lsp::PublishDiagnosticsParams {
1280// uri: path_to_buffer_uri.clone(),
1281// diagnostics: vec![diagnostic],
1282// version: None,
1283// },
1284// None,
1285// language::DiagnosticSourceKind::Pushed,
1286// &[],
1287// cx,
1288// )
1289// .unwrap();
1290// });
1291// });
1292
1293// let buffer = project
1294// .update(cx, |project, cx| {
1295// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1296// project.open_buffer(path, cx)
1297// })
1298// .await
1299// .unwrap();
1300
1301// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1302// let position = snapshot.anchor_before(language::Point::new(0, 0));
1303
1304// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1305// ep_store.request_prediction(&project, &buffer, position, cx)
1306// });
1307
1308// let (request, _respond_tx) = req_rx.next().await.unwrap();
1309
1310// assert_eq!(request.diagnostic_groups.len(), 1);
1311// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1312// .unwrap();
1313// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1314// assert_eq!(
1315// value,
1316// json!({
1317// "entries": [{
1318// "range": {
1319// "start": 8,
1320// "end": 10
1321// },
1322// "diagnostic": {
1323// "source": null,
1324// "code": null,
1325// "code_description": null,
1326// "severity": 1,
1327// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1328// "markdown": null,
1329// "group_id": 0,
1330// "is_primary": true,
1331// "is_disk_based": false,
1332// "is_unnecessary": false,
1333// "source_kind": "Pushed",
1334// "data": null,
1335// "underline": true
1336// }
1337// }],
1338// "primary_ix": 0
1339// })
1340// );
1341// }
1342
1343// Generate a model response that would apply the given diff to the active file.
1344fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1345 let excerpt =
1346 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1347 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1348
1349 PredictEditsV3Response {
1350 request_id: Uuid::new_v4().to_string(),
1351 output: new_excerpt,
1352 }
1353}
1354
/// Render the full prompt string that would be sent to the model for `request`.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
}
1358
/// Receiving ends of the channels the fake HTTP client forwards requests to.
/// Each item pairs the decoded request body with a oneshot sender the test
/// uses to decide when and how to respond.
struct RequestChannels {
    // Requests hitting the /predict_edits/v3 endpoint.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests hitting the /predict_edits/reject endpoint.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1366
/// Set up an `EditPredictionStore` backed by a fake HTTP client that routes
/// prediction and rejection requests into channels, letting the test inspect
/// every request body and control each response.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answer immediately with a fixed token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Forward the decoded request to the test and block
                        // until it replies through the oneshot.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Same forwarding scheme for rejection batches. If the
                        // test drops the sender, `res_rx.await?` errors, which
                        // simulates a failed request.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1435
// Permissive (0BSD) license text; tests place it in a worktree's LICENSE file
// to make that worktree count as eligible for data collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1437
// Verifies that `EditPrediction::interpolate` re-maps pending edits as the
// user types: edits the user has already typed are dropped, remaining edits
// shift with the buffer, and the prediction is invalidated (returns None)
// once the buffer diverges from it.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Hand-built prediction; the prompt inputs are irrelevant to
    // interpolation, so they're left empty.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unedited buffer: edits come back unchanged.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the "rem" span collapses the first edit to an insertion
        // and shifts the second edit left.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the predicted text incrementally ("R", then "E", then "M")
        // shrinks the first edit until it disappears entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting a typed character re-introduces the corresponding edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion removes that edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1553
// Feeds model responses wrapped in <|editable_region_*|> markers through
// `apply_edit_prediction` and checks that the markers are stripped and only
// the intended edit lands in the buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Rename edit: `word` -> `word_1` inside the editable region.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // String-extension edit; also ends a previously-unterminated statement.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1611
// A predicted insertion at the very end of the buffer should apply cleanly;
// the marker-wrapped response appends "ipsum" after the final line.
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
        ```animals.js
        <|start_of_file|>
        <|editable_region_start|>
        lorem
        ipsum
        <|editable_region_end|>
        ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1631
1632#[gpui::test]
1633async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1634 // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1635 // When the buffer ends without a trailing newline, but the model returns output
1636 // with a trailing newline, zeta2 should normalize both sides before diffing
1637 // so no spurious newline is inserted.
1638 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1639 let fs = FakeFs::new(cx.executor());
1640
1641 // Single line buffer with no trailing newline
1642 fs.insert_tree(
1643 "/root",
1644 json!({
1645 "foo.txt": "hello"
1646 }),
1647 )
1648 .await;
1649 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1650
1651 let buffer = project
1652 .update(cx, |project, cx| {
1653 let path = project
1654 .find_project_path(path!("root/foo.txt"), cx)
1655 .unwrap();
1656 project.open_buffer(path, cx)
1657 })
1658 .await
1659 .unwrap();
1660
1661 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1662 let position = snapshot.anchor_before(language::Point::new(0, 5));
1663
1664 ep_store.update(cx, |ep_store, cx| {
1665 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1666 });
1667
1668 let (_request, respond_tx) = requests.predict.next().await.unwrap();
1669
1670 // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1671 // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1672 let response = PredictEditsV3Response {
1673 request_id: Uuid::new_v4().to_string(),
1674 output: "hello world\n".to_string(),
1675 };
1676 respond_tx.send(response).unwrap();
1677
1678 cx.run_until_parked();
1679
1680 // The prediction should insert " world" without adding a newline
1681 ep_store.update(cx, |ep_store, cx| {
1682 let prediction = ep_store
1683 .prediction_at(&buffer, None, &project, cx)
1684 .expect("should have prediction");
1685 let edits: Vec<_> = prediction
1686 .edits
1687 .iter()
1688 .map(|(range, text)| {
1689 let snapshot = buffer.read(cx).snapshot();
1690 (range.to_offset(&snapshot), text.clone())
1691 })
1692 .collect();
1693 assert_eq!(edits, vec![(5..5, " world".into())]);
1694 });
1695}
1696
// The request's `can_collect_data` flag should track the user's data
// collection choice when the buffer lives in an open-source (0BSD-licensed)
// worktree.
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Enabled choice + open-source worktree => collection allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    // Disabling the choice turns the flag off on the next request.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1734
// Remote buffers are never eligible for data collection, even when the user
// has opted in.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    // A buffer with a non-zero replica id, i.e. shared from another host.
    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1762
1763#[gpui::test]
1764async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1765 init_test(cx);
1766
1767 let fs = project::FakeFs::new(cx.executor());
1768 fs.insert_tree(
1769 path!("/project"),
1770 json!({
1771 "LICENSE": BSD_0_TXT,
1772 ".env": "SECRET_KEY=secret"
1773 }),
1774 )
1775 .await;
1776
1777 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1778 let buffer = project
1779 .update(cx, |project, cx| {
1780 project.open_local_buffer("/project/.env", cx)
1781 })
1782 .await
1783 .unwrap();
1784
1785 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1786 ep_store.update(cx, |ep_store, _cx| {
1787 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1788 });
1789
1790 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1791 assert_eq!(
1792 captured_request.lock().clone().unwrap().can_collect_data,
1793 false
1794 );
1795}
1796
// A buffer with no backing file cannot be attributed to an open-source
// worktree, so it is never collected.
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1816
1817#[gpui::test]
1818async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1819 init_test(cx);
1820
1821 let fs = project::FakeFs::new(cx.executor());
1822 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1823 .await;
1824
1825 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1826 let buffer = project
1827 .update(cx, |project, cx| {
1828 project.open_local_buffer("/project/main.rs", cx)
1829 })
1830 .await
1831 .unwrap();
1832
1833 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1834 ep_store.update(cx, |ep_store, _cx| {
1835 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1836 });
1837
1838 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1839 assert_eq!(
1840 captured_request.lock().clone().unwrap().can_collect_data,
1841 false
1842 );
1843}
1844
// Moving a buffer's backing file from an open-source worktree into a
// closed-source one must flip `can_collect_data` from true to false.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While the file lives in the licensed worktree, collection is allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Simulate a move by swapping the buffer's file to one loaded from the
    // closed-source worktree.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1908
1909#[gpui::test]
1910async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1911 init_test(cx);
1912
1913 let fs = project::FakeFs::new(cx.executor());
1914 fs.insert_tree(
1915 path!("/worktree1"),
1916 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1917 )
1918 .await;
1919 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1920 .await;
1921
1922 let project = Project::test(
1923 fs.clone(),
1924 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1925 cx,
1926 )
1927 .await;
1928 let buffer = project
1929 .update(cx, |project, cx| {
1930 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1931 })
1932 .await
1933 .unwrap();
1934 let private_buffer = project
1935 .update(cx, |project, cx| {
1936 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1937 })
1938 .await
1939 .unwrap();
1940
1941 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1942 ep_store.update(cx, |ep_store, _cx| {
1943 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1944 });
1945
1946 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1947 assert_eq!(
1948 captured_request.lock().clone().unwrap().can_collect_data,
1949 true
1950 );
1951
1952 // this has a side effect of registering the buffer to watch for edits
1953 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1954 assert_eq!(
1955 captured_request.lock().clone().unwrap().can_collect_data,
1956 false
1957 );
1958
1959 private_buffer.update(cx, |private_buffer, cx| {
1960 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1961 });
1962
1963 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1964 assert_eq!(
1965 captured_request.lock().clone().unwrap().can_collect_data,
1966 false
1967 );
1968
1969 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1970 // included
1971 buffer.update(cx, |buffer, cx| {
1972 buffer.edit(
1973 [(
1974 0..0,
1975 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1976 )],
1977 None,
1978 cx,
1979 );
1980 });
1981
1982 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1983 assert_eq!(
1984 captured_request.lock().clone().unwrap().can_collect_data,
1985 true
1986 );
1987}
1988
1989fn init_test(cx: &mut TestAppContext) {
1990 cx.update(|cx| {
1991 let settings_store = SettingsStore::test(cx);
1992 cx.set_global(settings_store);
1993 });
1994}
1995
1996async fn apply_edit_prediction(
1997 buffer_content: &str,
1998 completion_response: &str,
1999 cx: &mut TestAppContext,
2000) -> String {
2001 let fs = project::FakeFs::new(cx.executor());
2002 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2003 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2004 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2005 *response.lock() = completion_response.to_string();
2006 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2007 buffer.update(cx, |buffer, cx| {
2008 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2009 });
2010 buffer.read_with(cx, |buffer, _| buffer.text())
2011}
2012
/// Register `buffer` with the store, request a prediction, and await the
/// resulting `EditPrediction`, panicking if none is produced.
///
/// NOTE(review): the cursor is hard-coded to `Point::new(1, 0)` (start of the
/// second row) — callers appear to pass multi-line content; confirm the point
/// is valid (or clipped safely) before reusing with single-line buffers.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
2029
/// Build an `EditPredictionStore` (Zeta1 model) wired to a fake server.
///
/// Returns the store plus two shared handles:
/// - the last `PredictEditsBody` the fake `/predict_edits/v2` endpoint
///   received (for asserting on request contents), and
/// - the response text the endpoint will return (mutable, so tests can swap
///   in their own completion).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    // Default completion: replaces the editable region with "hello world".
    let default_response = indoc! {"
        ```main.rs
        <|start_of_file|>
        <|editable_region_start|>
        hello world
        <|editable_region_end|>
        ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh: always succeeds with a fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction endpoint: record the request body, then
                    // answer with the currently-configured completion text.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Seed license detection for every worktree so data-collection
        // eligibility checks have a watcher to consult.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2120
2121fn to_completion_edits(
2122 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2123 buffer: &Entity<Buffer>,
2124 cx: &App,
2125) -> Vec<(Range<Anchor>, Arc<str>)> {
2126 let buffer = buffer.read(cx);
2127 iterator
2128 .into_iter()
2129 .map(|(range, text)| {
2130 (
2131 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2132 text,
2133 )
2134 })
2135 .collect()
2136}
2137
2138fn from_completion_edits(
2139 editor_edits: &[(Range<Anchor>, Arc<str>)],
2140 buffer: &Entity<Buffer>,
2141 cx: &App,
2142) -> Vec<(Range<usize>, Arc<str>)> {
2143 let buffer = buffer.read(cx);
2144 editor_edits
2145 .iter()
2146 .map(|(range, text)| {
2147 (
2148 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2149 text.clone(),
2150 )
2151 })
2152 .collect()
2153}
2154
// With no credentials and no custom predict-edits URL, a prediction request
// must fail rather than hit the default endpoint.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every request (including token refresh) is rejected with 401.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2212
// When a custom predict-edits URL is configured, predictions bypass the
// authentication requirement and the custom endpoint is called directly.
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Records whether the custom predict endpoint was ever reached.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                if uri.contains("predict") {
                    // Custom endpoint: answer with an OpenAI-style chat
                    // completion carrying a marker-wrapped edit.
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                            ```main.rs
                                            <|start_of_file|>
                                            <|editable_region_start|>
                                            fn main() {
                                                println!(\"Hello, world!\");
                                            }
                                            <|editable_region_end|>
                                            ```
                                            "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    // Everything else (e.g. token refresh) stays unauthorized.
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // The result itself isn't asserted; only that the custom URL was used.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2322
2323#[gpui::test]
2324fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2325 let buffer = cx.new(|cx| {
2326 Buffer::local(
2327 indoc! {"
2328 zero
2329 one
2330 two
2331 three
2332 four
2333 five
2334 six
2335 seven
2336 eight
2337 nine
2338 ten
2339 eleven
2340 twelve
2341 thirteen
2342 fourteen
2343 fifteen
2344 sixteen
2345 seventeen
2346 eighteen
2347 nineteen
2348 twenty
2349 twenty-one
2350 twenty-two
2351 twenty-three
2352 twenty-four
2353 "},
2354 cx,
2355 )
2356 });
2357
2358 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2359
2360 buffer.update(cx, |buffer, cx| {
2361 let point = Point::new(12, 0);
2362 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2363 let point = Point::new(8, 0);
2364 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2365 });
2366
2367 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2368
2369 let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2370
2371 assert_eq!(
2372 diff,
2373 indoc! {"
2374 @@ -6,10 +6,12 @@
2375 five
2376 six
2377 seven
2378 +FIRST INSERTION
2379 eight
2380 nine
2381 ten
2382 eleven
2383 +SECOND INSERTION
2384 twelve
2385 thirteen
2386 fourteen
2387 "}
2388 );
2389}
2390
#[ctor::ctor]
fn init_logger() {
    // Runs at binary load time (via the `ctor` crate), before any test
    // executes, so logging is available to every test in this file.
    zlog::init_test();
}