1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
10};
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::Point;
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::{path, rel_path::rel_path};
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
/// End-to-end check of `prediction_at` routing: a prediction for the active
/// buffer surfaces as `Local`, while a prediction targeting a different file
/// (here triggered by a published diagnostic) surfaces as `Jump` from the
/// active buffer but as `Local` when queried from the target buffer itself.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor inside "How" on line 1 of 1.txt.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction edits the buffer it was requested from, so it is Local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Clear it out so the next prediction is requested fresh.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(
            EditPredictionRejectReason::Discarded,
            &project,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // Publishing the diagnostic triggers a new prediction request, this time
    // answered with an edit to 2.txt rather than the active 1.txt.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from buffer1, the prediction lives in another file => Jump.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from the target buffer, the same prediction is Local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
174
/// Happy path for `request_prediction`: a single request goes out, the fake
/// model answers with a diff, and the resulting prediction contains exactly
/// one edit at the expected position with the expected inserted text.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // "How" -> "How are you?" is a single insertion of " are you?" at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
241
/// Verifies that edits made to a registered buffer are recorded as events and
/// included in the prediction request's prompt as a unified diff.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registration is what makes the store track this buffer's edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" into the empty second line (offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The edit above must show up in the prompt as a diff of foo.md.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
314
/// The plain `edit_history_for_project` getter coalesces nearby edits into a
/// single event regardless of timing, while the pause-splitting variant
/// (`edit_history_for_project_with_pause_split_last_event`) breaks the last
/// event into separate diffs at pauses longer than the grouping threshold.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Everything up to the pause...
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // ...and everything after it.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
409
/// Exercises line-span coalescing of edit-history events: edits within a few
/// lines of an existing event's range (checked against both its start and its
/// end) are merged into that event, while an edit far away starts a new one.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
588
589fn render_events(events: &[StoredEvent]) -> String {
590 events
591 .iter()
592 .map(|e| {
593 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
594 diff.as_str()
595 })
596 .collect::<Vec<_>>()
597 .join("\n---\n")
598}
599
/// An empty model response must yield no prediction, and the request must be
/// auto-reported as rejected with `EditPredictionRejectReason::Empty` and
/// `was_shown: false`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff; remember the request id for the assertion below.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
654
/// If the buffer already contains the predicted text by the time the response
/// arrives, interpolating the diff produces no remaining edits; the prediction
/// is dropped and reported with `EditPredictionRejectReason::InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the very edit the model is about to suggest, before it responds.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
714
/// Canned model response shared by several tests below: rewrites "How" to
/// "How are you?" in `root/foo.md`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
724
/// A newly completed prediction replaces the current one; the replaced
/// prediction is reported back with `EditPredictionRejectReason::Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is now current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
806
/// When a later response is considered worse than the current prediction
/// (here it predicts only a prefix of the current edit), the current one is
/// kept and the newcomer is rejected with
/// `EditPredictionRejectReason::CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is now current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
900
/// Out-of-order completion: when a later request's response arrives first,
/// earlier still-pending requests are canceled — their responses are ignored
/// and they are reported with `EditPredictionRejectReason::Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the (already canceled) first request respond.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
991
/// With two requests already in flight, starting a third cancels the second
/// (index 1 in `cancelled_predictions`). The first response still lands, the
/// canceled second response is ignored, the third replaces the first, and the
/// reject report contains both the Canceled and the Replaced entries.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The canceled second request responding must have no effect.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1127
/// Covers the rejection-reporting pipeline: rejections are debounced and
/// batched; hitting the batch-size limit flushes the first 50 immediately and
/// debounces the remainder; and a failed send is retried, merged with any
/// rejections queued since.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    // Nothing is sent until the debounce elapses.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                edit_prediction_types::EditPredictionDismissReason::Ignored,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            edit_prediction_types::EditPredictionDismissReason::Ignored,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1249
1250// Skipped until we start including diagnostics in prompt
1251// #[gpui::test]
1252// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1253// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1254// let fs = FakeFs::new(cx.executor());
1255// fs.insert_tree(
1256// "/root",
1257// json!({
1258// "foo.md": "Hello!\nBye"
1259// }),
1260// )
1261// .await;
1262// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1263
1264// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1265// let diagnostic = lsp::Diagnostic {
1266// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1267// severity: Some(lsp::DiagnosticSeverity::ERROR),
1268// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1269// ..Default::default()
1270// };
1271
1272// project.update(cx, |project, cx| {
1273// project.lsp_store().update(cx, |lsp_store, cx| {
1274// // Create some diagnostics
1275// lsp_store
1276// .update_diagnostics(
1277// LanguageServerId(0),
1278// lsp::PublishDiagnosticsParams {
1279// uri: path_to_buffer_uri.clone(),
1280// diagnostics: vec![diagnostic],
1281// version: None,
1282// },
1283// None,
1284// language::DiagnosticSourceKind::Pushed,
1285// &[],
1286// cx,
1287// )
1288// .unwrap();
1289// });
1290// });
1291
1292// let buffer = project
1293// .update(cx, |project, cx| {
1294// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1295// project.open_buffer(path, cx)
1296// })
1297// .await
1298// .unwrap();
1299
1300// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1301// let position = snapshot.anchor_before(language::Point::new(0, 0));
1302
1303// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1304// ep_store.request_prediction(&project, &buffer, position, cx)
1305// });
1306
1307// let (request, _respond_tx) = req_rx.next().await.unwrap();
1308
1309// assert_eq!(request.diagnostic_groups.len(), 1);
1310// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1311// .unwrap();
1312// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1313// assert_eq!(
1314// value,
1315// json!({
1316// "entries": [{
1317// "range": {
1318// "start": 8,
1319// "end": 10
1320// },
1321// "diagnostic": {
1322// "source": null,
1323// "code": null,
1324// "code_description": null,
1325// "severity": 1,
1326// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1327// "markdown": null,
1328// "group_id": 0,
1329// "is_primary": true,
1330// "is_disk_based": false,
1331// "is_unnecessary": false,
1332// "source_kind": "Pushed",
1333// "data": null,
1334// "underline": true
1335// }
1336// }],
1337// "primary_ix": 0
1338// })
1339// );
1340// }
1341
1342// Generate a model response that would apply the given diff to the active file.
1343fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1344 let excerpt =
1345 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1346 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1347
1348 PredictEditsV3Response {
1349 request_id: Uuid::new_v4().to_string(),
1350 output: new_excerpt,
1351 }
1352}
1353
/// Renders the prompt text the model would receive for `request`, using the
/// prompt version carried by the request itself.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
}
1357
/// Receivers for HTTP requests captured by the fake client built in
/// `init_test_with_fake_client`. Each item pairs the deserialized request
/// body with a oneshot sender the test fires to complete the response (or
/// drops to simulate a request failure).
struct RequestChannels {
    // Prediction requests posted to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Rejection batches posted to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1365
/// Sets up test globals and builds an `EditPredictionStore` backed by a fake
/// HTTP client that forwards prediction and rejection requests to the
/// returned [`RequestChannels`] instead of hitting the network.
///
/// Each captured request blocks the fake HTTP response until the test sends
/// (or drops, to simulate failure) the paired oneshot sender.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: respond immediately with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Forward the request to the test and wait for it to
                        // supply the response via the oneshot channel.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Same handshake for rejection batches; if the test
                        // drops the sender, `res_rx.await?` errors, which
                        // simulates a failed HTTP request.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1434
/// Full text of the 0BSD license; placing it as a worktree's LICENSE file
/// marks that worktree as open source in the data-collection tests below.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1436
1437#[gpui::test]
1438async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1439 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1440 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1441 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1442 });
1443
1444 let edit_preview = cx
1445 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1446 .await;
1447
1448 let prediction = EditPrediction {
1449 edits,
1450 edit_preview,
1451 buffer: buffer.clone(),
1452 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1453 id: EditPredictionId("the-id".into()),
1454 inputs: ZetaPromptInput {
1455 events: Default::default(),
1456 related_files: Default::default(),
1457 cursor_path: Path::new("").into(),
1458 cursor_excerpt: "".into(),
1459 editable_range_in_excerpt: 0..0,
1460 cursor_offset_in_excerpt: 0,
1461 },
1462 buffer_snapshotted_at: Instant::now(),
1463 response_received_at: Instant::now(),
1464 };
1465
1466 cx.update(|cx| {
1467 assert_eq!(
1468 from_completion_edits(
1469 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1470 &buffer,
1471 cx
1472 ),
1473 vec![(2..5, "REM".into()), (9..11, "".into())]
1474 );
1475
1476 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1477 assert_eq!(
1478 from_completion_edits(
1479 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1480 &buffer,
1481 cx
1482 ),
1483 vec![(2..2, "REM".into()), (6..8, "".into())]
1484 );
1485
1486 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1487 assert_eq!(
1488 from_completion_edits(
1489 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1490 &buffer,
1491 cx
1492 ),
1493 vec![(2..5, "REM".into()), (9..11, "".into())]
1494 );
1495
1496 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1497 assert_eq!(
1498 from_completion_edits(
1499 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1500 &buffer,
1501 cx
1502 ),
1503 vec![(3..3, "EM".into()), (7..9, "".into())]
1504 );
1505
1506 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1507 assert_eq!(
1508 from_completion_edits(
1509 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1510 &buffer,
1511 cx
1512 ),
1513 vec![(4..4, "M".into()), (8..10, "".into())]
1514 );
1515
1516 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1517 assert_eq!(
1518 from_completion_edits(
1519 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1520 &buffer,
1521 cx
1522 ),
1523 vec![(9..11, "".into())]
1524 );
1525
1526 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1527 assert_eq!(
1528 from_completion_edits(
1529 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1530 &buffer,
1531 cx
1532 ),
1533 vec![(4..4, "M".into()), (8..10, "".into())]
1534 );
1535
1536 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1537 assert_eq!(
1538 from_completion_edits(
1539 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1540 &buffer,
1541 cx
1542 ),
1543 vec![(4..4, "M".into())]
1544 );
1545
1546 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1547 assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1548 })
1549}
1550
1551#[gpui::test]
1552async fn test_clean_up_diff(cx: &mut TestAppContext) {
1553 init_test(cx);
1554
1555 assert_eq!(
1556 apply_edit_prediction(
1557 indoc! {"
1558 fn main() {
1559 let word_1 = \"lorem\";
1560 let range = word.len()..word.len();
1561 }
1562 "},
1563 indoc! {"
1564 <|editable_region_start|>
1565 fn main() {
1566 let word_1 = \"lorem\";
1567 let range = word_1.len()..word_1.len();
1568 }
1569
1570 <|editable_region_end|>
1571 "},
1572 cx,
1573 )
1574 .await,
1575 indoc! {"
1576 fn main() {
1577 let word_1 = \"lorem\";
1578 let range = word_1.len()..word_1.len();
1579 }
1580 "},
1581 );
1582
1583 assert_eq!(
1584 apply_edit_prediction(
1585 indoc! {"
1586 fn main() {
1587 let story = \"the quick\"
1588 }
1589 "},
1590 indoc! {"
1591 <|editable_region_start|>
1592 fn main() {
1593 let story = \"the quick brown fox jumps over the lazy dog\";
1594 }
1595
1596 <|editable_region_end|>
1597 "},
1598 cx,
1599 )
1600 .await,
1601 indoc! {"
1602 fn main() {
1603 let story = \"the quick brown fox jumps over the lazy dog\";
1604 }
1605 "},
1606 );
1607}
1608
1609#[gpui::test]
1610async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1611 init_test(cx);
1612
1613 let buffer_content = "lorem\n";
1614 let completion_response = indoc! {"
1615 ```animals.js
1616 <|start_of_file|>
1617 <|editable_region_start|>
1618 lorem
1619 ipsum
1620 <|editable_region_end|>
1621 ```"};
1622
1623 assert_eq!(
1624 apply_edit_prediction(buffer_content, completion_response, cx).await,
1625 "lorem\nipsum"
1626 );
1627}
1628
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of the only line ("hello" is 5 columns wide).
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve anchored edits to concrete offsets for comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
1693
1694#[gpui::test]
1695async fn test_can_collect_data(cx: &mut TestAppContext) {
1696 init_test(cx);
1697
1698 let fs = project::FakeFs::new(cx.executor());
1699 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1700 .await;
1701
1702 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1703 let buffer = project
1704 .update(cx, |project, cx| {
1705 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1706 })
1707 .await
1708 .unwrap();
1709
1710 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1711 ep_store.update(cx, |ep_store, _cx| {
1712 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1713 });
1714
1715 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1716 assert_eq!(
1717 captured_request.lock().clone().unwrap().can_collect_data,
1718 true
1719 );
1720
1721 ep_store.update(cx, |ep_store, _cx| {
1722 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1723 });
1724
1725 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1726 assert_eq!(
1727 captured_request.lock().clone().unwrap().can_collect_data,
1728 false
1729 );
1730}
1731
1732#[gpui::test]
1733async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1734 init_test(cx);
1735
1736 let fs = project::FakeFs::new(cx.executor());
1737 let project = Project::test(fs.clone(), [], cx).await;
1738
1739 let buffer = cx.new(|_cx| {
1740 Buffer::remote(
1741 language::BufferId::new(1).unwrap(),
1742 ReplicaId::new(1),
1743 language::Capability::ReadWrite,
1744 "fn main() {\n println!(\"Hello\");\n}",
1745 )
1746 });
1747
1748 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1749 ep_store.update(cx, |ep_store, _cx| {
1750 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1751 });
1752
1753 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1754 assert_eq!(
1755 captured_request.lock().clone().unwrap().can_collect_data,
1756 false
1757 );
1758}
1759
1760#[gpui::test]
1761async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1762 init_test(cx);
1763
1764 let fs = project::FakeFs::new(cx.executor());
1765 fs.insert_tree(
1766 path!("/project"),
1767 json!({
1768 "LICENSE": BSD_0_TXT,
1769 ".env": "SECRET_KEY=secret"
1770 }),
1771 )
1772 .await;
1773
1774 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1775 let buffer = project
1776 .update(cx, |project, cx| {
1777 project.open_local_buffer("/project/.env", cx)
1778 })
1779 .await
1780 .unwrap();
1781
1782 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1783 ep_store.update(cx, |ep_store, _cx| {
1784 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1785 });
1786
1787 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1788 assert_eq!(
1789 captured_request.lock().clone().unwrap().can_collect_data,
1790 false
1791 );
1792}
1793
1794#[gpui::test]
1795async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1796 init_test(cx);
1797
1798 let fs = project::FakeFs::new(cx.executor());
1799 let project = Project::test(fs.clone(), [], cx).await;
1800 let buffer = cx.new(|cx| Buffer::local("", cx));
1801
1802 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1803 ep_store.update(cx, |ep_store, _cx| {
1804 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1805 });
1806
1807 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1808 assert_eq!(
1809 captured_request.lock().clone().unwrap().can_collect_data,
1810 false
1811 );
1812}
1813
1814#[gpui::test]
1815async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1816 init_test(cx);
1817
1818 let fs = project::FakeFs::new(cx.executor());
1819 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1820 .await;
1821
1822 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1823 let buffer = project
1824 .update(cx, |project, cx| {
1825 project.open_local_buffer("/project/main.rs", cx)
1826 })
1827 .await
1828 .unwrap();
1829
1830 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1831 ep_store.update(cx, |ep_store, _cx| {
1832 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1833 });
1834
1835 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1836 assert_eq!(
1837 captured_request.lock().clone().unwrap().can_collect_data,
1838 false
1839 );
1840}
1841
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    // One worktree with a 0BSD LICENSE (collectable) and one without.
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While the buffer lives in the open-source worktree, collection is on.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Load a file handle from the closed-source worktree...
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    // ...and reassign it to the buffer, simulating the file being moved.
    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    // After the move, the buffer belongs to the closed-source worktree and
    // collection must switch off.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1905
1906#[gpui::test]
1907async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1908 init_test(cx);
1909
1910 let fs = project::FakeFs::new(cx.executor());
1911 fs.insert_tree(
1912 path!("/worktree1"),
1913 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1914 )
1915 .await;
1916 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1917 .await;
1918
1919 let project = Project::test(
1920 fs.clone(),
1921 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1922 cx,
1923 )
1924 .await;
1925 let buffer = project
1926 .update(cx, |project, cx| {
1927 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1928 })
1929 .await
1930 .unwrap();
1931 let private_buffer = project
1932 .update(cx, |project, cx| {
1933 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1934 })
1935 .await
1936 .unwrap();
1937
1938 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1939 ep_store.update(cx, |ep_store, _cx| {
1940 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1941 });
1942
1943 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1944 assert_eq!(
1945 captured_request.lock().clone().unwrap().can_collect_data,
1946 true
1947 );
1948
1949 // this has a side effect of registering the buffer to watch for edits
1950 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1951 assert_eq!(
1952 captured_request.lock().clone().unwrap().can_collect_data,
1953 false
1954 );
1955
1956 private_buffer.update(cx, |private_buffer, cx| {
1957 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1958 });
1959
1960 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1961 assert_eq!(
1962 captured_request.lock().clone().unwrap().can_collect_data,
1963 false
1964 );
1965
1966 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1967 // included
1968 buffer.update(cx, |buffer, cx| {
1969 buffer.edit(
1970 [(
1971 0..0,
1972 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1973 )],
1974 None,
1975 cx,
1976 );
1977 });
1978
1979 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1980 assert_eq!(
1981 captured_request.lock().clone().unwrap().can_collect_data,
1982 true
1983 );
1984}
1985
1986fn init_test(cx: &mut TestAppContext) {
1987 cx.update(|cx| {
1988 let settings_store = SettingsStore::test(cx);
1989 cx.set_global(settings_store);
1990 });
1991}
1992
1993async fn apply_edit_prediction(
1994 buffer_content: &str,
1995 completion_response: &str,
1996 cx: &mut TestAppContext,
1997) -> String {
1998 let fs = project::FakeFs::new(cx.executor());
1999 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2000 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2001 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2002 *response.lock() = completion_response.to_string();
2003 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2004 buffer.update(cx, |buffer, cx| {
2005 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2006 });
2007 buffer.read_with(cx, |buffer, _| buffer.text())
2008}
2009
/// Registers `buffer` with the store, requests a prediction, and returns it.
///
/// The cursor is hard-coded to row 1, column 0 — callers' fixtures are
/// written with that position in mind. Panics if the request fails or
/// produces no prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
2026
/// Builds an `EditPredictionStore` (Zeta1 model) wired to a fake HTTP server
/// for the legacy `/predict_edits/v2` endpoint.
///
/// Returns:
/// - the store,
/// - a slot holding the most recently captured prediction request body, and
/// - a slot whose current contents are served as the completion response.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    // Served unless a test overwrites the response slot first.
    let default_response = indoc! {"
        ```main.rs
        <|start_of_file|>
        <|editable_region_start|>
        hello world
        <|editable_region_end|>
        ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh: respond with a fixed dummy token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Capture the request body for the test, then serve the
                    // contents of the response slot.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Seed a license watcher per worktree so open-source detection works
        // in the data-collection tests.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2117
2118fn to_completion_edits(
2119 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2120 buffer: &Entity<Buffer>,
2121 cx: &App,
2122) -> Vec<(Range<Anchor>, Arc<str>)> {
2123 let buffer = buffer.read(cx);
2124 iterator
2125 .into_iter()
2126 .map(|(range, text)| {
2127 (
2128 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2129 text,
2130 )
2131 })
2132 .collect()
2133}
2134
2135fn from_completion_edits(
2136 editor_edits: &[(Range<Anchor>, Arc<str>)],
2137 buffer: &Entity<Buffer>,
2138 cx: &App,
2139) -> Vec<(Range<usize>, Arc<str>)> {
2140 let buffer = buffer.read(cx);
2141 editor_edits
2142 .iter()
2143 .map(|(range, text)| {
2144 (
2145 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2146 text.clone(),
2147 )
2148 })
2149 .collect()
2150}
2151
2152#[gpui::test]
2153async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2154 init_test(cx);
2155
2156 let fs = FakeFs::new(cx.executor());
2157 fs.insert_tree(
2158 "/project",
2159 serde_json::json!({
2160 "main.rs": "fn main() {\n \n}\n"
2161 }),
2162 )
2163 .await;
2164
2165 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2166
2167 let http_client = FakeHttpClient::create(|_req| async move {
2168 Ok(gpui::http_client::Response::builder()
2169 .status(401)
2170 .body("Unauthorized".into())
2171 .unwrap())
2172 });
2173
2174 let client =
2175 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2176 cx.update(|cx| {
2177 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2178 });
2179
2180 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2181
2182 let buffer = project
2183 .update(cx, |project, cx| {
2184 let path = project
2185 .find_project_path(path!("/project/main.rs"), cx)
2186 .unwrap();
2187 project.open_buffer(path, cx)
2188 })
2189 .await
2190 .unwrap();
2191
2192 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2193 ep_store.update(cx, |ep_store, cx| {
2194 ep_store.register_buffer(&buffer, &project, cx)
2195 });
2196 cx.background_executor.run_until_parked();
2197
2198 let completion_task = ep_store.update(cx, |ep_store, cx| {
2199 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2200 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2201 });
2202
2203 let result = completion_task.await;
2204 assert!(
2205 result.is_err(),
2206 "Without authentication and without custom URL, prediction should fail"
2207 );
2208}
2209
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Records whether the custom predict endpoint was actually reached.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    // Fake server: the custom predict endpoint answers with an OpenAI-style
    // chat completion; every other request (e.g. token refresh) gets 401, so
    // the client is effectively unauthenticated.
    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                if uri.contains("predict") {
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                                ```main.rs
                                                <|start_of_file|>
                                                <|editable_region_start|>
                                                fn main() {
                                                    println!(\"Hello, world!\");
                                                }
                                                <|editable_region_end|>
                                                ```
                                            "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    // Point the store at the custom URL before requesting.
    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // The result itself doesn't matter here — only that the custom endpoint
    // was hit despite the 401s on every other route.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2319
2320#[gpui::test]
2321fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2322 let buffer = cx.new(|cx| {
2323 Buffer::local(
2324 indoc! {"
2325 zero
2326 one
2327 two
2328 three
2329 four
2330 five
2331 six
2332 seven
2333 eight
2334 nine
2335 ten
2336 eleven
2337 twelve
2338 thirteen
2339 fourteen
2340 fifteen
2341 sixteen
2342 seventeen
2343 eighteen
2344 nineteen
2345 twenty
2346 twenty-one
2347 twenty-two
2348 twenty-three
2349 twenty-four
2350 "},
2351 cx,
2352 )
2353 });
2354
2355 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2356
2357 buffer.update(cx, |buffer, cx| {
2358 let point = Point::new(12, 0);
2359 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2360 let point = Point::new(8, 0);
2361 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2362 });
2363
2364 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2365
2366 let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2367
2368 assert_eq!(
2369 diff,
2370 indoc! {"
2371 @@ -6,10 +6,12 @@
2372 five
2373 six
2374 seven
2375 +FIRST INSERTION
2376 eight
2377 nine
2378 ten
2379 eleven
2380 +SECOND INSERTION
2381 twelve
2382 thirteen
2383 fourteen
2384 "}
2385 );
2386}
2387
/// Registered via `ctor` to run at binary startup, so the test logger is
/// installed before any test in this file executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}