1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
10};
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::App;
16use gpui::{
17 Entity, TestAppContext,
18 http_client::{FakeHttpClient, Response},
19};
20use indoc::indoc;
21use language::{Buffer, Point};
22use lsp::LanguageServerId;
23use parking_lot::Mutex;
24use pretty_assertions::{assert_eq, assert_matches};
25use project::{FakeFs, Project};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, sync::Arc, time::Duration};
29use util::{path, rel_path::rel_path};
30use uuid::Uuid;
31use zeta_prompt::ZetaPromptInput;
32
33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
34
35#[gpui::test]
36async fn test_current_state(cx: &mut TestAppContext) {
37 let (ep_store, mut requests) = init_test_with_fake_client(cx);
38 let fs = FakeFs::new(cx.executor());
39 fs.insert_tree(
40 "/root",
41 json!({
42 "1.txt": "Hello!\nHow\nBye\n",
43 "2.txt": "Hola!\nComo\nAdios\n"
44 }),
45 )
46 .await;
47 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
48
49 let buffer1 = project
50 .update(cx, |project, cx| {
51 let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
52 project.set_active_path(Some(path.clone()), cx);
53 project.open_buffer(path, cx)
54 })
55 .await
56 .unwrap();
57 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
58 let position = snapshot1.anchor_before(language::Point::new(1, 3));
59
60 ep_store.update(cx, |ep_store, cx| {
61 ep_store.register_project(&project, cx);
62 ep_store.register_buffer(&buffer1, &project, cx);
63 });
64
65 // Prediction for current file
66
67 ep_store.update(cx, |ep_store, cx| {
68 ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
69 });
70 let (request, respond_tx) = requests.predict.next().await.unwrap();
71
72 respond_tx
73 .send(model_response(
74 &request,
75 indoc! {r"
76 --- a/root/1.txt
77 +++ b/root/1.txt
78 @@ ... @@
79 Hello!
80 -How
81 +How are you?
82 Bye
83 "},
84 ))
85 .unwrap();
86
87 cx.run_until_parked();
88
89 ep_store.update(cx, |ep_store, cx| {
90 let prediction = ep_store
91 .prediction_at(&buffer1, None, &project, cx)
92 .unwrap();
93 assert_matches!(prediction, BufferEditPrediction::Local { .. });
94 });
95
96 ep_store.update(cx, |ep_store, _cx| {
97 ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
98 });
99
100 // Prediction for diagnostic in another file
101
102 let diagnostic = lsp::Diagnostic {
103 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
104 severity: Some(lsp::DiagnosticSeverity::ERROR),
105 message: "Sentence is incomplete".to_string(),
106 ..Default::default()
107 };
108
109 project.update(cx, |project, cx| {
110 project.lsp_store().update(cx, |lsp_store, cx| {
111 lsp_store
112 .update_diagnostics(
113 LanguageServerId(0),
114 lsp::PublishDiagnosticsParams {
115 uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
116 diagnostics: vec![diagnostic],
117 version: None,
118 },
119 None,
120 language::DiagnosticSourceKind::Pushed,
121 &[],
122 cx,
123 )
124 .unwrap();
125 });
126 });
127
128 let (request, respond_tx) = requests.predict.next().await.unwrap();
129 respond_tx
130 .send(model_response(
131 &request,
132 indoc! {r#"
133 --- a/root/2.txt
134 +++ b/root/2.txt
135 @@ ... @@
136 Hola!
137 -Como
138 +Como estas?
139 Adios
140 "#},
141 ))
142 .unwrap();
143 cx.run_until_parked();
144
145 ep_store.update(cx, |ep_store, cx| {
146 let prediction = ep_store
147 .prediction_at(&buffer1, None, &project, cx)
148 .unwrap();
149 assert_matches!(
150 prediction,
151 BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
152 );
153 });
154
155 let buffer2 = project
156 .update(cx, |project, cx| {
157 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
158 project.open_buffer(path, cx)
159 })
160 .await
161 .unwrap();
162
163 ep_store.update(cx, |ep_store, cx| {
164 let prediction = ep_store
165 .prediction_at(&buffer2, None, &project, cx)
166 .unwrap();
167 assert_matches!(prediction, BufferEditPrediction::Local { .. });
168 });
169}
170
171#[gpui::test]
172async fn test_simple_request(cx: &mut TestAppContext) {
173 let (ep_store, mut requests) = init_test_with_fake_client(cx);
174 let fs = FakeFs::new(cx.executor());
175 fs.insert_tree(
176 "/root",
177 json!({
178 "foo.md": "Hello!\nHow\nBye\n"
179 }),
180 )
181 .await;
182 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
183
184 let buffer = project
185 .update(cx, |project, cx| {
186 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
187 project.open_buffer(path, cx)
188 })
189 .await
190 .unwrap();
191 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
192 let position = snapshot.anchor_before(language::Point::new(1, 3));
193
194 let prediction_task = ep_store.update(cx, |ep_store, cx| {
195 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
196 });
197
198 let (request, respond_tx) = requests.predict.next().await.unwrap();
199
200 // TODO Put back when we have a structured request again
201 // assert_eq!(
202 // request.excerpt_path.as_ref(),
203 // Path::new(path!("root/foo.md"))
204 // );
205 // assert_eq!(
206 // request.cursor_point,
207 // Point {
208 // line: Line(1),
209 // column: 3
210 // }
211 // );
212
213 respond_tx
214 .send(model_response(
215 &request,
216 indoc! { r"
217 --- a/root/foo.md
218 +++ b/root/foo.md
219 @@ ... @@
220 Hello!
221 -How
222 +How are you?
223 Bye
224 "},
225 ))
226 .unwrap();
227
228 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
229
230 assert_eq!(prediction.edits.len(), 1);
231 assert_eq!(
232 prediction.edits[0].0.to_point(&snapshot).start,
233 language::Point::new(1, 3)
234 );
235 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
236}
237
238#[gpui::test]
239async fn test_request_events(cx: &mut TestAppContext) {
240 let (ep_store, mut requests) = init_test_with_fake_client(cx);
241 let fs = FakeFs::new(cx.executor());
242 fs.insert_tree(
243 "/root",
244 json!({
245 "foo.md": "Hello!\n\nBye\n"
246 }),
247 )
248 .await;
249 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
250
251 let buffer = project
252 .update(cx, |project, cx| {
253 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
254 project.open_buffer(path, cx)
255 })
256 .await
257 .unwrap();
258
259 ep_store.update(cx, |ep_store, cx| {
260 ep_store.register_buffer(&buffer, &project, cx);
261 });
262
263 buffer.update(cx, |buffer, cx| {
264 buffer.edit(vec![(7..7, "How")], None, cx);
265 });
266
267 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
268 let position = snapshot.anchor_before(language::Point::new(1, 3));
269
270 let prediction_task = ep_store.update(cx, |ep_store, cx| {
271 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
272 });
273
274 let (request, respond_tx) = requests.predict.next().await.unwrap();
275
276 let prompt = prompt_from_request(&request);
277 assert!(
278 prompt.contains(indoc! {"
279 --- a/root/foo.md
280 +++ b/root/foo.md
281 @@ -1,3 +1,3 @@
282 Hello!
283 -
284 +How
285 Bye
286 "}),
287 "{prompt}"
288 );
289
290 respond_tx
291 .send(model_response(
292 &request,
293 indoc! {r#"
294 --- a/root/foo.md
295 +++ b/root/foo.md
296 @@ ... @@
297 Hello!
298 -How
299 +How are you?
300 Bye
301 "#},
302 ))
303 .unwrap();
304
305 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
306
307 assert_eq!(prediction.edits.len(), 1);
308 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
309}
310
311#[gpui::test]
312async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
313 let (ep_store, _requests) = init_test_with_fake_client(cx);
314 let fs = FakeFs::new(cx.executor());
315 fs.insert_tree(
316 "/root",
317 json!({
318 "foo.md": "Hello!\n\nBye\n"
319 }),
320 )
321 .await;
322 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
323
324 let buffer = project
325 .update(cx, |project, cx| {
326 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
327 project.open_buffer(path, cx)
328 })
329 .await
330 .unwrap();
331
332 ep_store.update(cx, |ep_store, cx| {
333 ep_store.register_buffer(&buffer, &project, cx);
334 });
335
336 // First burst: insert "How"
337 buffer.update(cx, |buffer, cx| {
338 buffer.edit(vec![(7..7, "How")], None, cx);
339 });
340
341 // Simulate a pause longer than the grouping threshold (e.g. 500ms).
342 cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
343 cx.run_until_parked();
344
345 // Second burst: append " are you?" immediately after "How" on the same line.
346 //
347 // Keeping both bursts on the same line ensures the existing line-span coalescing logic
348 // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
349 buffer.update(cx, |buffer, cx| {
350 buffer.edit(vec![(10..10, " are you?")], None, cx);
351 });
352
353 // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
354 // advanced after the pause boundary is recorded, making pause-splitting deterministic.
355 buffer.update(cx, |buffer, cx| {
356 buffer.edit(vec![(19..19, "!")], None, cx);
357 });
358
359 // Without time-based splitting, there is one event.
360 let events = ep_store.update(cx, |ep_store, cx| {
361 ep_store.edit_history_for_project(&project, cx)
362 });
363 assert_eq!(events.len(), 1);
364 let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
365 assert_eq!(
366 diff.as_str(),
367 indoc! {"
368 @@ -1,3 +1,3 @@
369 Hello!
370 -
371 +How are you?!
372 Bye
373 "}
374 );
375
376 // With time-based splitting, there are two distinct events.
377 let events = ep_store.update(cx, |ep_store, cx| {
378 ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
379 });
380 assert_eq!(events.len(), 2);
381 let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
382 assert_eq!(
383 diff.as_str(),
384 indoc! {"
385 @@ -1,3 +1,3 @@
386 Hello!
387 -
388 +How
389 Bye
390 "}
391 );
392
393 let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
394 assert_eq!(
395 diff.as_str(),
396 indoc! {"
397 @@ -1,3 +1,3 @@
398 Hello!
399 -How
400 +How are you?!
401 Bye
402 "}
403 );
404}
405
406#[gpui::test]
407async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
408 let (ep_store, _requests) = init_test_with_fake_client(cx);
409 let fs = FakeFs::new(cx.executor());
410
411 // Create a file with 30 lines to test line-based coalescing
412 let content = (1..=30)
413 .map(|i| format!("Line {}\n", i))
414 .collect::<String>();
415 fs.insert_tree(
416 "/root",
417 json!({
418 "foo.md": content
419 }),
420 )
421 .await;
422 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
423
424 let buffer = project
425 .update(cx, |project, cx| {
426 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
427 project.open_buffer(path, cx)
428 })
429 .await
430 .unwrap();
431
432 ep_store.update(cx, |ep_store, cx| {
433 ep_store.register_buffer(&buffer, &project, cx);
434 });
435
436 // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
437 buffer.update(cx, |buffer, cx| {
438 let start = Point::new(10, 0).to_offset(buffer);
439 let end = Point::new(13, 0).to_offset(buffer);
440 buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
441 });
442
443 let events = ep_store.update(cx, |ep_store, cx| {
444 ep_store.edit_history_for_project(&project, cx)
445 });
446 assert_eq!(
447 render_events(&events),
448 indoc! {"
449 @@ -8,9 +8,8 @@
450 Line 8
451 Line 9
452 Line 10
453 -Line 11
454 -Line 12
455 -Line 13
456 +Middle A
457 +Middle B
458 Line 14
459 Line 15
460 Line 16
461 "},
462 "After first edit"
463 );
464
465 // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
466 // This tests that coalescing considers the START of the existing range
467 buffer.update(cx, |buffer, cx| {
468 let offset = Point::new(5, 0).to_offset(buffer);
469 buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
470 });
471
472 let events = ep_store.update(cx, |ep_store, cx| {
473 ep_store.edit_history_for_project(&project, cx)
474 });
475 assert_eq!(
476 render_events(&events),
477 indoc! {"
478 @@ -3,14 +3,14 @@
479 Line 3
480 Line 4
481 Line 5
482 +Above
483 Line 6
484 Line 7
485 Line 8
486 Line 9
487 Line 10
488 -Line 11
489 -Line 12
490 -Line 13
491 +Middle A
492 +Middle B
493 Line 14
494 Line 15
495 Line 16
496 "},
497 "After inserting above (should coalesce)"
498 );
499
500 // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
501 // This tests that coalescing considers the END of the existing range
502 buffer.update(cx, |buffer, cx| {
503 let offset = Point::new(14, 0).to_offset(buffer);
504 buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
505 });
506
507 let events = ep_store.update(cx, |ep_store, cx| {
508 ep_store.edit_history_for_project(&project, cx)
509 });
510 assert_eq!(
511 render_events(&events),
512 indoc! {"
513 @@ -3,15 +3,16 @@
514 Line 3
515 Line 4
516 Line 5
517 +Above
518 Line 6
519 Line 7
520 Line 8
521 Line 9
522 Line 10
523 -Line 11
524 -Line 12
525 -Line 13
526 +Middle A
527 +Middle B
528 Line 14
529 +Below
530 Line 15
531 Line 16
532 Line 17
533 "},
534 "After inserting below (should coalesce)"
535 );
536
537 // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
538 // This should NOT coalesce - creates a new event
539 buffer.update(cx, |buffer, cx| {
540 let offset = Point::new(25, 0).to_offset(buffer);
541 buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
542 });
543
544 let events = ep_store.update(cx, |ep_store, cx| {
545 ep_store.edit_history_for_project(&project, cx)
546 });
547 assert_eq!(
548 render_events(&events),
549 indoc! {"
550 @@ -3,15 +3,16 @@
551 Line 3
552 Line 4
553 Line 5
554 +Above
555 Line 6
556 Line 7
557 Line 8
558 Line 9
559 Line 10
560 -Line 11
561 -Line 12
562 -Line 13
563 +Middle A
564 +Middle B
565 Line 14
566 +Below
567 Line 15
568 Line 16
569 Line 17
570
571 ---
572 @@ -23,6 +23,7 @@
573 Line 22
574 Line 23
575 Line 24
576 +Far below
577 Line 25
578 Line 26
579 Line 27
580 "},
581 "After inserting far below (should NOT coalesce)"
582 );
583}
584
585fn render_events(events: &[StoredEvent]) -> String {
586 events
587 .iter()
588 .map(|e| {
589 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
590 diff.as_str()
591 })
592 .collect::<Vec<_>>()
593 .join("\n---\n")
594}
595
596#[gpui::test]
597async fn test_empty_prediction(cx: &mut TestAppContext) {
598 let (ep_store, mut requests) = init_test_with_fake_client(cx);
599 let fs = FakeFs::new(cx.executor());
600 fs.insert_tree(
601 "/root",
602 json!({
603 "foo.md": "Hello!\nHow\nBye\n"
604 }),
605 )
606 .await;
607 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
608
609 let buffer = project
610 .update(cx, |project, cx| {
611 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
612 project.open_buffer(path, cx)
613 })
614 .await
615 .unwrap();
616 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
617 let position = snapshot.anchor_before(language::Point::new(1, 3));
618
619 ep_store.update(cx, |ep_store, cx| {
620 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
621 });
622
623 let (request, respond_tx) = requests.predict.next().await.unwrap();
624 let response = model_response(&request, "");
625 let id = response.request_id.clone();
626 respond_tx.send(response).unwrap();
627
628 cx.run_until_parked();
629
630 ep_store.update(cx, |ep_store, cx| {
631 assert!(
632 ep_store
633 .prediction_at(&buffer, None, &project, cx)
634 .is_none()
635 );
636 });
637
638 // prediction is reported as rejected
639 let (reject_request, _) = requests.reject.next().await.unwrap();
640
641 assert_eq!(
642 &reject_request.rejections,
643 &[EditPredictionRejection {
644 request_id: id,
645 reason: EditPredictionRejectReason::Empty,
646 was_shown: false
647 }]
648 );
649}
650
651#[gpui::test]
652async fn test_interpolated_empty(cx: &mut TestAppContext) {
653 let (ep_store, mut requests) = init_test_with_fake_client(cx);
654 let fs = FakeFs::new(cx.executor());
655 fs.insert_tree(
656 "/root",
657 json!({
658 "foo.md": "Hello!\nHow\nBye\n"
659 }),
660 )
661 .await;
662 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
663
664 let buffer = project
665 .update(cx, |project, cx| {
666 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
667 project.open_buffer(path, cx)
668 })
669 .await
670 .unwrap();
671 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
672 let position = snapshot.anchor_before(language::Point::new(1, 3));
673
674 ep_store.update(cx, |ep_store, cx| {
675 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
676 });
677
678 let (request, respond_tx) = requests.predict.next().await.unwrap();
679
680 buffer.update(cx, |buffer, cx| {
681 buffer.set_text("Hello!\nHow are you?\nBye", cx);
682 });
683
684 let response = model_response(&request, SIMPLE_DIFF);
685 let id = response.request_id.clone();
686 respond_tx.send(response).unwrap();
687
688 cx.run_until_parked();
689
690 ep_store.update(cx, |ep_store, cx| {
691 assert!(
692 ep_store
693 .prediction_at(&buffer, None, &project, cx)
694 .is_none()
695 );
696 });
697
698 // prediction is reported as rejected
699 let (reject_request, _) = requests.reject.next().await.unwrap();
700
701 assert_eq!(
702 &reject_request.rejections,
703 &[EditPredictionRejection {
704 request_id: id,
705 reason: EditPredictionRejectReason::InterpolatedEmpty,
706 was_shown: false
707 }]
708 );
709}
710
/// Diff used as the fake model response by several tests in this file: within
/// `root/foo.md` it replaces the line `How` with `How are you?`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
720
721#[gpui::test]
722async fn test_replace_current(cx: &mut TestAppContext) {
723 let (ep_store, mut requests) = init_test_with_fake_client(cx);
724 let fs = FakeFs::new(cx.executor());
725 fs.insert_tree(
726 "/root",
727 json!({
728 "foo.md": "Hello!\nHow\nBye\n"
729 }),
730 )
731 .await;
732 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
733
734 let buffer = project
735 .update(cx, |project, cx| {
736 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
737 project.open_buffer(path, cx)
738 })
739 .await
740 .unwrap();
741 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
742 let position = snapshot.anchor_before(language::Point::new(1, 3));
743
744 ep_store.update(cx, |ep_store, cx| {
745 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
746 });
747
748 let (request, respond_tx) = requests.predict.next().await.unwrap();
749 let first_response = model_response(&request, SIMPLE_DIFF);
750 let first_id = first_response.request_id.clone();
751 respond_tx.send(first_response).unwrap();
752
753 cx.run_until_parked();
754
755 ep_store.update(cx, |ep_store, cx| {
756 assert_eq!(
757 ep_store
758 .prediction_at(&buffer, None, &project, cx)
759 .unwrap()
760 .id
761 .0,
762 first_id
763 );
764 });
765
766 // a second request is triggered
767 ep_store.update(cx, |ep_store, cx| {
768 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
769 });
770
771 let (request, respond_tx) = requests.predict.next().await.unwrap();
772 let second_response = model_response(&request, SIMPLE_DIFF);
773 let second_id = second_response.request_id.clone();
774 respond_tx.send(second_response).unwrap();
775
776 cx.run_until_parked();
777
778 ep_store.update(cx, |ep_store, cx| {
779 // second replaces first
780 assert_eq!(
781 ep_store
782 .prediction_at(&buffer, None, &project, cx)
783 .unwrap()
784 .id
785 .0,
786 second_id
787 );
788 });
789
790 // first is reported as replaced
791 let (reject_request, _) = requests.reject.next().await.unwrap();
792
793 assert_eq!(
794 &reject_request.rejections,
795 &[EditPredictionRejection {
796 request_id: first_id,
797 reason: EditPredictionRejectReason::Replaced,
798 was_shown: false
799 }]
800 );
801}
802
803#[gpui::test]
804async fn test_current_preferred(cx: &mut TestAppContext) {
805 let (ep_store, mut requests) = init_test_with_fake_client(cx);
806 let fs = FakeFs::new(cx.executor());
807 fs.insert_tree(
808 "/root",
809 json!({
810 "foo.md": "Hello!\nHow\nBye\n"
811 }),
812 )
813 .await;
814 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
815
816 let buffer = project
817 .update(cx, |project, cx| {
818 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
819 project.open_buffer(path, cx)
820 })
821 .await
822 .unwrap();
823 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
824 let position = snapshot.anchor_before(language::Point::new(1, 3));
825
826 ep_store.update(cx, |ep_store, cx| {
827 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
828 });
829
830 let (request, respond_tx) = requests.predict.next().await.unwrap();
831 let first_response = model_response(&request, SIMPLE_DIFF);
832 let first_id = first_response.request_id.clone();
833 respond_tx.send(first_response).unwrap();
834
835 cx.run_until_parked();
836
837 ep_store.update(cx, |ep_store, cx| {
838 assert_eq!(
839 ep_store
840 .prediction_at(&buffer, None, &project, cx)
841 .unwrap()
842 .id
843 .0,
844 first_id
845 );
846 });
847
848 // a second request is triggered
849 ep_store.update(cx, |ep_store, cx| {
850 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
851 });
852
853 let (request, respond_tx) = requests.predict.next().await.unwrap();
854 // worse than current prediction
855 let second_response = model_response(
856 &request,
857 indoc! { r"
858 --- a/root/foo.md
859 +++ b/root/foo.md
860 @@ ... @@
861 Hello!
862 -How
863 +How are
864 Bye
865 "},
866 );
867 let second_id = second_response.request_id.clone();
868 respond_tx.send(second_response).unwrap();
869
870 cx.run_until_parked();
871
872 ep_store.update(cx, |ep_store, cx| {
873 // first is preferred over second
874 assert_eq!(
875 ep_store
876 .prediction_at(&buffer, None, &project, cx)
877 .unwrap()
878 .id
879 .0,
880 first_id
881 );
882 });
883
884 // second is reported as rejected
885 let (reject_request, _) = requests.reject.next().await.unwrap();
886
887 assert_eq!(
888 &reject_request.rejections,
889 &[EditPredictionRejection {
890 request_id: second_id,
891 reason: EditPredictionRejectReason::CurrentPreferred,
892 was_shown: false
893 }]
894 );
895}
896
897#[gpui::test]
898async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
899 let (ep_store, mut requests) = init_test_with_fake_client(cx);
900 let fs = FakeFs::new(cx.executor());
901 fs.insert_tree(
902 "/root",
903 json!({
904 "foo.md": "Hello!\nHow\nBye\n"
905 }),
906 )
907 .await;
908 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
909
910 let buffer = project
911 .update(cx, |project, cx| {
912 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
913 project.open_buffer(path, cx)
914 })
915 .await
916 .unwrap();
917 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
918 let position = snapshot.anchor_before(language::Point::new(1, 3));
919
920 // start two refresh tasks
921 ep_store.update(cx, |ep_store, cx| {
922 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
923 });
924
925 let (request1, respond_first) = requests.predict.next().await.unwrap();
926
927 ep_store.update(cx, |ep_store, cx| {
928 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
929 });
930
931 let (request, respond_second) = requests.predict.next().await.unwrap();
932
933 // wait for throttle
934 cx.run_until_parked();
935
936 // second responds first
937 let second_response = model_response(&request, SIMPLE_DIFF);
938 let second_id = second_response.request_id.clone();
939 respond_second.send(second_response).unwrap();
940
941 cx.run_until_parked();
942
943 ep_store.update(cx, |ep_store, cx| {
944 // current prediction is second
945 assert_eq!(
946 ep_store
947 .prediction_at(&buffer, None, &project, cx)
948 .unwrap()
949 .id
950 .0,
951 second_id
952 );
953 });
954
955 let first_response = model_response(&request1, SIMPLE_DIFF);
956 let first_id = first_response.request_id.clone();
957 respond_first.send(first_response).unwrap();
958
959 cx.run_until_parked();
960
961 ep_store.update(cx, |ep_store, cx| {
962 // current prediction is still second, since first was cancelled
963 assert_eq!(
964 ep_store
965 .prediction_at(&buffer, None, &project, cx)
966 .unwrap()
967 .id
968 .0,
969 second_id
970 );
971 });
972
973 // first is reported as rejected
974 let (reject_request, _) = requests.reject.next().await.unwrap();
975
976 cx.run_until_parked();
977
978 assert_eq!(
979 &reject_request.rejections,
980 &[EditPredictionRejection {
981 request_id: first_id,
982 reason: EditPredictionRejectReason::Canceled,
983 was_shown: false
984 }]
985 );
986}
987
988#[gpui::test]
989async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
990 let (ep_store, mut requests) = init_test_with_fake_client(cx);
991 let fs = FakeFs::new(cx.executor());
992 fs.insert_tree(
993 "/root",
994 json!({
995 "foo.md": "Hello!\nHow\nBye\n"
996 }),
997 )
998 .await;
999 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1000
1001 let buffer = project
1002 .update(cx, |project, cx| {
1003 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1004 project.open_buffer(path, cx)
1005 })
1006 .await
1007 .unwrap();
1008 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1009 let position = snapshot.anchor_before(language::Point::new(1, 3));
1010
1011 // start two refresh tasks
1012 ep_store.update(cx, |ep_store, cx| {
1013 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1014 });
1015
1016 let (request1, respond_first) = requests.predict.next().await.unwrap();
1017
1018 ep_store.update(cx, |ep_store, cx| {
1019 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1020 });
1021
1022 let (request2, respond_second) = requests.predict.next().await.unwrap();
1023
1024 // wait for throttle, so requests are sent
1025 cx.run_until_parked();
1026
1027 ep_store.update(cx, |ep_store, cx| {
1028 // start a third request
1029 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1030
1031 // 2 are pending, so 2nd is cancelled
1032 assert_eq!(
1033 ep_store
1034 .get_or_init_project(&project, cx)
1035 .cancelled_predictions
1036 .iter()
1037 .copied()
1038 .collect::<Vec<_>>(),
1039 [1]
1040 );
1041 });
1042
1043 // wait for throttle
1044 cx.run_until_parked();
1045
1046 let (request3, respond_third) = requests.predict.next().await.unwrap();
1047
1048 let first_response = model_response(&request1, SIMPLE_DIFF);
1049 let first_id = first_response.request_id.clone();
1050 respond_first.send(first_response).unwrap();
1051
1052 cx.run_until_parked();
1053
1054 ep_store.update(cx, |ep_store, cx| {
1055 // current prediction is first
1056 assert_eq!(
1057 ep_store
1058 .prediction_at(&buffer, None, &project, cx)
1059 .unwrap()
1060 .id
1061 .0,
1062 first_id
1063 );
1064 });
1065
1066 let cancelled_response = model_response(&request2, SIMPLE_DIFF);
1067 let cancelled_id = cancelled_response.request_id.clone();
1068 respond_second.send(cancelled_response).unwrap();
1069
1070 cx.run_until_parked();
1071
1072 ep_store.update(cx, |ep_store, cx| {
1073 // current prediction is still first, since second was cancelled
1074 assert_eq!(
1075 ep_store
1076 .prediction_at(&buffer, None, &project, cx)
1077 .unwrap()
1078 .id
1079 .0,
1080 first_id
1081 );
1082 });
1083
1084 let third_response = model_response(&request3, SIMPLE_DIFF);
1085 let third_response_id = third_response.request_id.clone();
1086 respond_third.send(third_response).unwrap();
1087
1088 cx.run_until_parked();
1089
1090 ep_store.update(cx, |ep_store, cx| {
1091 // third completes and replaces first
1092 assert_eq!(
1093 ep_store
1094 .prediction_at(&buffer, None, &project, cx)
1095 .unwrap()
1096 .id
1097 .0,
1098 third_response_id
1099 );
1100 });
1101
1102 // second is reported as rejected
1103 let (reject_request, _) = requests.reject.next().await.unwrap();
1104
1105 cx.run_until_parked();
1106
1107 assert_eq!(
1108 &reject_request.rejections,
1109 &[
1110 EditPredictionRejection {
1111 request_id: cancelled_id,
1112 reason: EditPredictionRejectReason::Canceled,
1113 was_shown: false
1114 },
1115 EditPredictionRejection {
1116 request_id: first_id,
1117 reason: EditPredictionRejectReason::Replaced,
1118 was_shown: false
1119 }
1120 ]
1121 );
1122}
1123
1124#[gpui::test]
1125async fn test_rejections_flushing(cx: &mut TestAppContext) {
1126 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1127
1128 ep_store.update(cx, |ep_store, _cx| {
1129 ep_store.reject_prediction(
1130 EditPredictionId("test-1".into()),
1131 EditPredictionRejectReason::Discarded,
1132 false,
1133 );
1134 ep_store.reject_prediction(
1135 EditPredictionId("test-2".into()),
1136 EditPredictionRejectReason::Canceled,
1137 true,
1138 );
1139 });
1140
1141 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1142 cx.run_until_parked();
1143
1144 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1145 respond_tx.send(()).unwrap();
1146
1147 // batched
1148 assert_eq!(reject_request.rejections.len(), 2);
1149 assert_eq!(
1150 reject_request.rejections[0],
1151 EditPredictionRejection {
1152 request_id: "test-1".to_string(),
1153 reason: EditPredictionRejectReason::Discarded,
1154 was_shown: false
1155 }
1156 );
1157 assert_eq!(
1158 reject_request.rejections[1],
1159 EditPredictionRejection {
1160 request_id: "test-2".to_string(),
1161 reason: EditPredictionRejectReason::Canceled,
1162 was_shown: true
1163 }
1164 );
1165
1166 // Reaching batch size limit sends without debounce
1167 ep_store.update(cx, |ep_store, _cx| {
1168 for i in 0..70 {
1169 ep_store.reject_prediction(
1170 EditPredictionId(format!("batch-{}", i).into()),
1171 EditPredictionRejectReason::Discarded,
1172 false,
1173 );
1174 }
1175 });
1176
1177 // First MAX/2 items are sent immediately
1178 cx.run_until_parked();
1179 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1180 respond_tx.send(()).unwrap();
1181
1182 assert_eq!(reject_request.rejections.len(), 50);
1183 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1184 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1185
1186 // Remaining items are debounced with the next batch
1187 cx.executor().advance_clock(Duration::from_secs(15));
1188 cx.run_until_parked();
1189
1190 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1191 respond_tx.send(()).unwrap();
1192
1193 assert_eq!(reject_request.rejections.len(), 20);
1194 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1195 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1196
1197 // Request failure
1198 ep_store.update(cx, |ep_store, _cx| {
1199 ep_store.reject_prediction(
1200 EditPredictionId("retry-1".into()),
1201 EditPredictionRejectReason::Discarded,
1202 false,
1203 );
1204 });
1205
1206 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1207 cx.run_until_parked();
1208
1209 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1210 assert_eq!(reject_request.rejections.len(), 1);
1211 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1212 // Simulate failure
1213 drop(_respond_tx);
1214
1215 // Add another rejection
1216 ep_store.update(cx, |ep_store, _cx| {
1217 ep_store.reject_prediction(
1218 EditPredictionId("retry-2".into()),
1219 EditPredictionRejectReason::Discarded,
1220 false,
1221 );
1222 });
1223
1224 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1225 cx.run_until_parked();
1226
1227 // Retry should include both the failed item and the new one
1228 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1229 respond_tx.send(()).unwrap();
1230
1231 assert_eq!(reject_request.rejections.len(), 2);
1232 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1233 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1234}
1235
1236// Skipped until we start including diagnostics in prompt
1237// #[gpui::test]
1238// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1239// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1240// let fs = FakeFs::new(cx.executor());
1241// fs.insert_tree(
1242// "/root",
1243// json!({
1244// "foo.md": "Hello!\nBye"
1245// }),
1246// )
1247// .await;
1248// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1249
1250// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1251// let diagnostic = lsp::Diagnostic {
1252// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1253// severity: Some(lsp::DiagnosticSeverity::ERROR),
1254// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1255// ..Default::default()
1256// };
1257
1258// project.update(cx, |project, cx| {
1259// project.lsp_store().update(cx, |lsp_store, cx| {
1260// // Create some diagnostics
1261// lsp_store
1262// .update_diagnostics(
1263// LanguageServerId(0),
1264// lsp::PublishDiagnosticsParams {
1265// uri: path_to_buffer_uri.clone(),
1266// diagnostics: vec![diagnostic],
1267// version: None,
1268// },
1269// None,
1270// language::DiagnosticSourceKind::Pushed,
1271// &[],
1272// cx,
1273// )
1274// .unwrap();
1275// });
1276// });
1277
1278// let buffer = project
1279// .update(cx, |project, cx| {
1280// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1281// project.open_buffer(path, cx)
1282// })
1283// .await
1284// .unwrap();
1285
1286// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1287// let position = snapshot.anchor_before(language::Point::new(0, 0));
1288
1289// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1290// ep_store.request_prediction(&project, &buffer, position, cx)
1291// });
1292
1293// let (request, _respond_tx) = req_rx.next().await.unwrap();
1294
1295// assert_eq!(request.diagnostic_groups.len(), 1);
1296// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1297// .unwrap();
1298// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1299// assert_eq!(
1300// value,
1301// json!({
1302// "entries": [{
1303// "range": {
1304// "start": 8,
1305// "end": 10
1306// },
1307// "diagnostic": {
1308// "source": null,
1309// "code": null,
1310// "code_description": null,
1311// "severity": 1,
1312// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1313// "markdown": null,
1314// "group_id": 0,
1315// "is_primary": true,
1316// "is_disk_based": false,
1317// "is_unnecessary": false,
1318// "source_kind": "Pushed",
1319// "data": null,
1320// "underline": true
1321// }
1322// }],
1323// "primary_ix": 0
1324// })
1325// );
1326// }
1327
1328// Generate a model response that would apply the given diff to the active file.
1329fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1330 let excerpt =
1331 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1332 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1333
1334 PredictEditsV3Response {
1335 request_id: Uuid::new_v4().to_string(),
1336 output: new_excerpt,
1337 }
1338}
1339
1340fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1341 zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
1342}
1343
/// Receiving ends of the channels on which the fake HTTP client created by
/// `init_test_with_fake_client` forwards the requests it intercepts.
///
/// Each item pairs the deserialized request body with a oneshot sender. The fake
/// handler awaits the matching receiver before building its HTTP response, so the
/// request stays pending until the test sends a value; dropping the sender makes the
/// request fail instead.
struct RequestChannels {
    /// Requests sent to `/predict_edits/v3`, answered with a `PredictEditsV3Response`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Requests sent to `/predict_edits/reject`, answered with `()`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1351
1352fn init_test_with_fake_client(
1353 cx: &mut TestAppContext,
1354) -> (Entity<EditPredictionStore>, RequestChannels) {
1355 cx.update(move |cx| {
1356 let settings_store = SettingsStore::test(cx);
1357 cx.set_global(settings_store);
1358 zlog::init_test();
1359
1360 let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1361 let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1362
1363 let http_client = FakeHttpClient::create({
1364 move |req| {
1365 let uri = req.uri().path().to_string();
1366 let mut body = req.into_body();
1367 let predict_req_tx = predict_req_tx.clone();
1368 let reject_req_tx = reject_req_tx.clone();
1369 async move {
1370 let resp = match uri.as_str() {
1371 "/client/llm_tokens" => serde_json::to_string(&json!({
1372 "token": "test"
1373 }))
1374 .unwrap(),
1375 "/predict_edits/v3" => {
1376 let mut buf = Vec::new();
1377 body.read_to_end(&mut buf).await.ok();
1378 let req = serde_json::from_slice(&buf).unwrap();
1379
1380 let (res_tx, res_rx) = oneshot::channel();
1381 predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1382 serde_json::to_string(&res_rx.await?).unwrap()
1383 }
1384 "/predict_edits/reject" => {
1385 let mut buf = Vec::new();
1386 body.read_to_end(&mut buf).await.ok();
1387 let req = serde_json::from_slice(&buf).unwrap();
1388
1389 let (res_tx, res_rx) = oneshot::channel();
1390 reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1391 serde_json::to_string(&res_rx.await?).unwrap()
1392 }
1393 _ => {
1394 panic!("Unexpected path: {}", uri)
1395 }
1396 };
1397
1398 Ok(Response::builder().body(resp.into()).unwrap())
1399 }
1400 }
1401 });
1402
1403 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1404 client.cloud_client().set_credentials(1, "test".into());
1405
1406 language_model::init(client.clone(), cx);
1407
1408 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1409 let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1410
1411 (
1412 ep_store,
1413 RequestChannels {
1414 predict: predict_req_rx,
1415 reject: reject_req_rx,
1416 },
1417 )
1418 })
1419}
1420
/// Contents of `../license_examples/0bsd.txt`, embedded at compile time. Tests in this
/// file write it as a worktree's `LICENSE` file.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1422
1423#[gpui::test]
1424async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1425 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1426 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1427 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1428 });
1429
1430 let edit_preview = cx
1431 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1432 .await;
1433
1434 let prediction = EditPrediction {
1435 edits,
1436 cursor_position: None,
1437 edit_preview,
1438 buffer: buffer.clone(),
1439 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1440 id: EditPredictionId("the-id".into()),
1441 inputs: ZetaPromptInput {
1442 events: Default::default(),
1443 related_files: Default::default(),
1444 cursor_path: Path::new("").into(),
1445 cursor_excerpt: "".into(),
1446 editable_range_in_excerpt: 0..0,
1447 cursor_offset_in_excerpt: 0,
1448 excerpt_start_row: None,
1449 },
1450 buffer_snapshotted_at: Instant::now(),
1451 response_received_at: Instant::now(),
1452 };
1453
1454 cx.update(|cx| {
1455 assert_eq!(
1456 from_completion_edits(
1457 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1458 &buffer,
1459 cx
1460 ),
1461 vec![(2..5, "REM".into()), (9..11, "".into())]
1462 );
1463
1464 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1465 assert_eq!(
1466 from_completion_edits(
1467 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1468 &buffer,
1469 cx
1470 ),
1471 vec![(2..2, "REM".into()), (6..8, "".into())]
1472 );
1473
1474 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1475 assert_eq!(
1476 from_completion_edits(
1477 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1478 &buffer,
1479 cx
1480 ),
1481 vec![(2..5, "REM".into()), (9..11, "".into())]
1482 );
1483
1484 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1485 assert_eq!(
1486 from_completion_edits(
1487 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1488 &buffer,
1489 cx
1490 ),
1491 vec![(3..3, "EM".into()), (7..9, "".into())]
1492 );
1493
1494 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1495 assert_eq!(
1496 from_completion_edits(
1497 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1498 &buffer,
1499 cx
1500 ),
1501 vec![(4..4, "M".into()), (8..10, "".into())]
1502 );
1503
1504 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1505 assert_eq!(
1506 from_completion_edits(
1507 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1508 &buffer,
1509 cx
1510 ),
1511 vec![(9..11, "".into())]
1512 );
1513
1514 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1515 assert_eq!(
1516 from_completion_edits(
1517 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1518 &buffer,
1519 cx
1520 ),
1521 vec![(4..4, "M".into()), (8..10, "".into())]
1522 );
1523
1524 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1525 assert_eq!(
1526 from_completion_edits(
1527 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1528 &buffer,
1529 cx
1530 ),
1531 vec![(4..4, "M".into())]
1532 );
1533
1534 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1535 assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1536 })
1537}
1538
1539#[gpui::test]
1540async fn test_clean_up_diff(cx: &mut TestAppContext) {
1541 init_test(cx);
1542
1543 assert_eq!(
1544 apply_edit_prediction(
1545 indoc! {"
1546 fn main() {
1547 let word_1 = \"lorem\";
1548 let range = word.len()..word.len();
1549 }
1550 "},
1551 indoc! {"
1552 <|editable_region_start|>
1553 fn main() {
1554 let word_1 = \"lorem\";
1555 let range = word_1.len()..word_1.len();
1556 }
1557
1558 <|editable_region_end|>
1559 "},
1560 cx,
1561 )
1562 .await,
1563 indoc! {"
1564 fn main() {
1565 let word_1 = \"lorem\";
1566 let range = word_1.len()..word_1.len();
1567 }
1568 "},
1569 );
1570
1571 assert_eq!(
1572 apply_edit_prediction(
1573 indoc! {"
1574 fn main() {
1575 let story = \"the quick\"
1576 }
1577 "},
1578 indoc! {"
1579 <|editable_region_start|>
1580 fn main() {
1581 let story = \"the quick brown fox jumps over the lazy dog\";
1582 }
1583
1584 <|editable_region_end|>
1585 "},
1586 cx,
1587 )
1588 .await,
1589 indoc! {"
1590 fn main() {
1591 let story = \"the quick brown fox jumps over the lazy dog\";
1592 }
1593 "},
1594 );
1595}
1596
1597#[gpui::test]
1598async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1599 init_test(cx);
1600
1601 let buffer_content = "lorem\n";
1602 let completion_response = indoc! {"
1603 ```animals.js
1604 <|start_of_file|>
1605 <|editable_region_start|>
1606 lorem
1607 ipsum
1608 <|editable_region_end|>
1609 ```"};
1610
1611 assert_eq!(
1612 apply_edit_prediction(buffer_content, completion_response, cx).await,
1613 "lorem\nipsum"
1614 );
1615}
1616
1617#[gpui::test]
1618async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1619 // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1620 // When the buffer ends without a trailing newline, but the model returns output
1621 // with a trailing newline, zeta2 should normalize both sides before diffing
1622 // so no spurious newline is inserted.
1623 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1624 let fs = FakeFs::new(cx.executor());
1625
1626 // Single line buffer with no trailing newline
1627 fs.insert_tree(
1628 "/root",
1629 json!({
1630 "foo.txt": "hello"
1631 }),
1632 )
1633 .await;
1634 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1635
1636 let buffer = project
1637 .update(cx, |project, cx| {
1638 let path = project
1639 .find_project_path(path!("root/foo.txt"), cx)
1640 .unwrap();
1641 project.open_buffer(path, cx)
1642 })
1643 .await
1644 .unwrap();
1645
1646 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1647 let position = snapshot.anchor_before(language::Point::new(0, 5));
1648
1649 ep_store.update(cx, |ep_store, cx| {
1650 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1651 });
1652
1653 let (_request, respond_tx) = requests.predict.next().await.unwrap();
1654
1655 // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1656 // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1657 let response = PredictEditsV3Response {
1658 request_id: Uuid::new_v4().to_string(),
1659 output: "hello world\n".to_string(),
1660 };
1661 respond_tx.send(response).unwrap();
1662
1663 cx.run_until_parked();
1664
1665 // The prediction should insert " world" without adding a newline
1666 ep_store.update(cx, |ep_store, cx| {
1667 let prediction = ep_store
1668 .prediction_at(&buffer, None, &project, cx)
1669 .expect("should have prediction");
1670 let edits: Vec<_> = prediction
1671 .edits
1672 .iter()
1673 .map(|(range, text)| {
1674 let snapshot = buffer.read(cx).snapshot();
1675 (range.to_offset(&snapshot), text.clone())
1676 })
1677 .collect();
1678 assert_eq!(edits, vec![(5..5, " world".into())]);
1679 });
1680}
1681
1682#[gpui::test]
1683async fn test_can_collect_data(cx: &mut TestAppContext) {
1684 init_test(cx);
1685
1686 let fs = project::FakeFs::new(cx.executor());
1687 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1688 .await;
1689
1690 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1691 let buffer = project
1692 .update(cx, |project, cx| {
1693 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1694 })
1695 .await
1696 .unwrap();
1697
1698 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1699 ep_store.update(cx, |ep_store, _cx| {
1700 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1701 });
1702
1703 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1704 assert_eq!(
1705 captured_request.lock().clone().unwrap().can_collect_data,
1706 true
1707 );
1708
1709 ep_store.update(cx, |ep_store, _cx| {
1710 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1711 });
1712
1713 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1714 assert_eq!(
1715 captured_request.lock().clone().unwrap().can_collect_data,
1716 false
1717 );
1718}
1719
1720#[gpui::test]
1721async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1722 init_test(cx);
1723
1724 let fs = project::FakeFs::new(cx.executor());
1725 let project = Project::test(fs.clone(), [], cx).await;
1726
1727 let buffer = cx.new(|_cx| {
1728 Buffer::remote(
1729 language::BufferId::new(1).unwrap(),
1730 ReplicaId::new(1),
1731 language::Capability::ReadWrite,
1732 "fn main() {\n println!(\"Hello\");\n}",
1733 )
1734 });
1735
1736 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1737 ep_store.update(cx, |ep_store, _cx| {
1738 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1739 });
1740
1741 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1742 assert_eq!(
1743 captured_request.lock().clone().unwrap().can_collect_data,
1744 false
1745 );
1746}
1747
1748#[gpui::test]
1749async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1750 init_test(cx);
1751
1752 let fs = project::FakeFs::new(cx.executor());
1753 fs.insert_tree(
1754 path!("/project"),
1755 json!({
1756 "LICENSE": BSD_0_TXT,
1757 ".env": "SECRET_KEY=secret"
1758 }),
1759 )
1760 .await;
1761
1762 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1763 let buffer = project
1764 .update(cx, |project, cx| {
1765 project.open_local_buffer("/project/.env", cx)
1766 })
1767 .await
1768 .unwrap();
1769
1770 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1771 ep_store.update(cx, |ep_store, _cx| {
1772 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1773 });
1774
1775 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1776 assert_eq!(
1777 captured_request.lock().clone().unwrap().can_collect_data,
1778 false
1779 );
1780}
1781
1782#[gpui::test]
1783async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1784 init_test(cx);
1785
1786 let fs = project::FakeFs::new(cx.executor());
1787 let project = Project::test(fs.clone(), [], cx).await;
1788 let buffer = cx.new(|cx| Buffer::local("", cx));
1789
1790 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1791 ep_store.update(cx, |ep_store, _cx| {
1792 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1793 });
1794
1795 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1796 assert_eq!(
1797 captured_request.lock().clone().unwrap().can_collect_data,
1798 false
1799 );
1800}
1801
1802#[gpui::test]
1803async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1804 init_test(cx);
1805
1806 let fs = project::FakeFs::new(cx.executor());
1807 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1808 .await;
1809
1810 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1811 let buffer = project
1812 .update(cx, |project, cx| {
1813 project.open_local_buffer("/project/main.rs", cx)
1814 })
1815 .await
1816 .unwrap();
1817
1818 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1819 ep_store.update(cx, |ep_store, _cx| {
1820 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1821 });
1822
1823 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1824 assert_eq!(
1825 captured_request.lock().clone().unwrap().can_collect_data,
1826 false
1827 );
1828}
1829
1830#[gpui::test]
1831async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1832 init_test(cx);
1833
1834 let fs = project::FakeFs::new(cx.executor());
1835 fs.insert_tree(
1836 path!("/open_source_worktree"),
1837 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1838 )
1839 .await;
1840 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1841 .await;
1842
1843 let project = Project::test(
1844 fs.clone(),
1845 [
1846 path!("/open_source_worktree").as_ref(),
1847 path!("/closed_source_worktree").as_ref(),
1848 ],
1849 cx,
1850 )
1851 .await;
1852 let buffer = project
1853 .update(cx, |project, cx| {
1854 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1855 })
1856 .await
1857 .unwrap();
1858
1859 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1860 ep_store.update(cx, |ep_store, _cx| {
1861 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1862 });
1863
1864 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1865 assert_eq!(
1866 captured_request.lock().clone().unwrap().can_collect_data,
1867 true
1868 );
1869
1870 let closed_source_file = project
1871 .update(cx, |project, cx| {
1872 let worktree2 = project
1873 .worktree_for_root_name("closed_source_worktree", cx)
1874 .unwrap();
1875 worktree2.update(cx, |worktree2, cx| {
1876 worktree2.load_file(rel_path("main.rs"), cx)
1877 })
1878 })
1879 .await
1880 .unwrap()
1881 .file;
1882
1883 buffer.update(cx, |buffer, cx| {
1884 buffer.file_updated(closed_source_file, cx);
1885 });
1886
1887 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1888 assert_eq!(
1889 captured_request.lock().clone().unwrap().can_collect_data,
1890 false
1891 );
1892}
1893
1894#[gpui::test]
1895async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1896 init_test(cx);
1897
1898 let fs = project::FakeFs::new(cx.executor());
1899 fs.insert_tree(
1900 path!("/worktree1"),
1901 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1902 )
1903 .await;
1904 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1905 .await;
1906
1907 let project = Project::test(
1908 fs.clone(),
1909 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1910 cx,
1911 )
1912 .await;
1913 let buffer = project
1914 .update(cx, |project, cx| {
1915 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1916 })
1917 .await
1918 .unwrap();
1919 let private_buffer = project
1920 .update(cx, |project, cx| {
1921 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1922 })
1923 .await
1924 .unwrap();
1925
1926 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1927 ep_store.update(cx, |ep_store, _cx| {
1928 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1929 });
1930
1931 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1932 assert_eq!(
1933 captured_request.lock().clone().unwrap().can_collect_data,
1934 true
1935 );
1936
1937 // this has a side effect of registering the buffer to watch for edits
1938 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1939 assert_eq!(
1940 captured_request.lock().clone().unwrap().can_collect_data,
1941 false
1942 );
1943
1944 private_buffer.update(cx, |private_buffer, cx| {
1945 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1946 });
1947
1948 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1949 assert_eq!(
1950 captured_request.lock().clone().unwrap().can_collect_data,
1951 false
1952 );
1953
1954 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1955 // included
1956 buffer.update(cx, |buffer, cx| {
1957 buffer.edit(
1958 [(
1959 0..0,
1960 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1961 )],
1962 None,
1963 cx,
1964 );
1965 });
1966
1967 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1968 assert_eq!(
1969 captured_request.lock().clone().unwrap().can_collect_data,
1970 true
1971 );
1972}
1973
1974fn init_test(cx: &mut TestAppContext) {
1975 cx.update(|cx| {
1976 let settings_store = SettingsStore::test(cx);
1977 cx.set_global(settings_store);
1978 });
1979}
1980
1981async fn apply_edit_prediction(
1982 buffer_content: &str,
1983 completion_response: &str,
1984 cx: &mut TestAppContext,
1985) -> String {
1986 let fs = project::FakeFs::new(cx.executor());
1987 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1988 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1989 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1990 *response.lock() = completion_response.to_string();
1991 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1992 buffer.update(cx, |buffer, cx| {
1993 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1994 });
1995 buffer.read_with(cx, |buffer, _| buffer.text())
1996}
1997
1998async fn run_edit_prediction(
1999 buffer: &Entity<Buffer>,
2000 project: &Entity<Project>,
2001 ep_store: &Entity<EditPredictionStore>,
2002 cx: &mut TestAppContext,
2003) -> EditPrediction {
2004 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2005 ep_store.update(cx, |ep_store, cx| {
2006 ep_store.register_buffer(buffer, &project, cx)
2007 });
2008 cx.background_executor.run_until_parked();
2009 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2010 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2011 });
2012 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2013}
2014
2015async fn make_test_ep_store(
2016 project: &Entity<Project>,
2017 cx: &mut TestAppContext,
2018) -> (
2019 Entity<EditPredictionStore>,
2020 Arc<Mutex<Option<PredictEditsBody>>>,
2021 Arc<Mutex<String>>,
2022) {
2023 let default_response = indoc! {"
2024 ```main.rs
2025 <|start_of_file|>
2026 <|editable_region_start|>
2027 hello world
2028 <|editable_region_end|>
2029 ```"
2030 };
2031 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2032 let completion_response: Arc<Mutex<String>> =
2033 Arc::new(Mutex::new(default_response.to_string()));
2034 let http_client = FakeHttpClient::create({
2035 let captured_request = captured_request.clone();
2036 let completion_response = completion_response.clone();
2037 let mut next_request_id = 0;
2038 move |req| {
2039 let captured_request = captured_request.clone();
2040 let completion_response = completion_response.clone();
2041 async move {
2042 match (req.method(), req.uri().path()) {
2043 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2044 .status(200)
2045 .body(
2046 serde_json::to_string(&CreateLlmTokenResponse {
2047 token: LlmToken("the-llm-token".to_string()),
2048 })
2049 .unwrap()
2050 .into(),
2051 )
2052 .unwrap()),
2053 (&Method::POST, "/predict_edits/v2") => {
2054 let mut request_body = String::new();
2055 req.into_body().read_to_string(&mut request_body).await?;
2056 *captured_request.lock() =
2057 Some(serde_json::from_str(&request_body).unwrap());
2058 next_request_id += 1;
2059 Ok(http_client::Response::builder()
2060 .status(200)
2061 .body(
2062 serde_json::to_string(&PredictEditsResponse {
2063 request_id: format!("request-{next_request_id}"),
2064 output_excerpt: completion_response.lock().clone(),
2065 })
2066 .unwrap()
2067 .into(),
2068 )
2069 .unwrap())
2070 }
2071 _ => Ok(http_client::Response::builder()
2072 .status(404)
2073 .body("Not Found".into())
2074 .unwrap()),
2075 }
2076 }
2077 }
2078 });
2079
2080 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2081 cx.update(|cx| {
2082 RefreshLlmTokenListener::register(client.clone(), cx);
2083 });
2084 let _server = FakeServer::for_client(42, &client, cx).await;
2085
2086 let ep_store = cx.new(|cx| {
2087 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2088 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2089
2090 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2091 for worktree in worktrees {
2092 let worktree_id = worktree.read(cx).id();
2093 ep_store
2094 .get_or_init_project(project, cx)
2095 .license_detection_watchers
2096 .entry(worktree_id)
2097 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2098 }
2099
2100 ep_store
2101 });
2102
2103 (ep_store, captured_request, completion_response)
2104}
2105
2106fn to_completion_edits(
2107 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2108 buffer: &Entity<Buffer>,
2109 cx: &App,
2110) -> Vec<(Range<Anchor>, Arc<str>)> {
2111 let buffer = buffer.read(cx);
2112 iterator
2113 .into_iter()
2114 .map(|(range, text)| {
2115 (
2116 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2117 text,
2118 )
2119 })
2120 .collect()
2121}
2122
2123fn from_completion_edits(
2124 editor_edits: &[(Range<Anchor>, Arc<str>)],
2125 buffer: &Entity<Buffer>,
2126 cx: &App,
2127) -> Vec<(Range<usize>, Arc<str>)> {
2128 let buffer = buffer.read(cx);
2129 editor_edits
2130 .iter()
2131 .map(|(range, text)| {
2132 (
2133 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2134 text.clone(),
2135 )
2136 })
2137 .collect()
2138}
2139
2140#[gpui::test]
2141async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2142 init_test(cx);
2143
2144 let fs = FakeFs::new(cx.executor());
2145 fs.insert_tree(
2146 "/project",
2147 serde_json::json!({
2148 "main.rs": "fn main() {\n \n}\n"
2149 }),
2150 )
2151 .await;
2152
2153 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2154
2155 let http_client = FakeHttpClient::create(|_req| async move {
2156 Ok(gpui::http_client::Response::builder()
2157 .status(401)
2158 .body("Unauthorized".into())
2159 .unwrap())
2160 });
2161
2162 let client =
2163 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2164 cx.update(|cx| {
2165 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2166 });
2167
2168 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2169
2170 let buffer = project
2171 .update(cx, |project, cx| {
2172 let path = project
2173 .find_project_path(path!("/project/main.rs"), cx)
2174 .unwrap();
2175 project.open_buffer(path, cx)
2176 })
2177 .await
2178 .unwrap();
2179
2180 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2181 ep_store.update(cx, |ep_store, cx| {
2182 ep_store.register_buffer(&buffer, &project, cx)
2183 });
2184 cx.background_executor.run_until_parked();
2185
2186 let completion_task = ep_store.update(cx, |ep_store, cx| {
2187 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2188 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2189 });
2190
2191 let result = completion_task.await;
2192 assert!(
2193 result.is_err(),
2194 "Without authentication and without custom URL, prediction should fail"
2195 );
2196}
2197
2198#[gpui::test]
2199async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2200 init_test(cx);
2201
2202 let fs = FakeFs::new(cx.executor());
2203 fs.insert_tree(
2204 "/project",
2205 serde_json::json!({
2206 "main.rs": "fn main() {\n \n}\n"
2207 }),
2208 )
2209 .await;
2210
2211 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2212
2213 let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2214 let predict_called_clone = predict_called.clone();
2215
2216 let http_client = FakeHttpClient::create({
2217 move |req| {
2218 let uri = req.uri().path().to_string();
2219 let predict_called = predict_called_clone.clone();
2220 async move {
2221 if uri.contains("predict") {
2222 predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2223 Ok(gpui::http_client::Response::builder()
2224 .body(
2225 serde_json::to_string(&open_ai::Response {
2226 id: "test-123".to_string(),
2227 object: "chat.completion".to_string(),
2228 created: 0,
2229 model: "test".to_string(),
2230 usage: open_ai::Usage {
2231 prompt_tokens: 0,
2232 completion_tokens: 0,
2233 total_tokens: 0,
2234 },
2235 choices: vec![open_ai::Choice {
2236 index: 0,
2237 message: open_ai::RequestMessage::Assistant {
2238 content: Some(open_ai::MessageContent::Plain(
2239 indoc! {"
2240 ```main.rs
2241 <|start_of_file|>
2242 <|editable_region_start|>
2243 fn main() {
2244 println!(\"Hello, world!\");
2245 }
2246 <|editable_region_end|>
2247 ```
2248 "}
2249 .to_string(),
2250 )),
2251 tool_calls: vec![],
2252 },
2253 finish_reason: Some("stop".to_string()),
2254 }],
2255 })
2256 .unwrap()
2257 .into(),
2258 )
2259 .unwrap())
2260 } else {
2261 Ok(gpui::http_client::Response::builder()
2262 .status(401)
2263 .body("Unauthorized".into())
2264 .unwrap())
2265 }
2266 }
2267 }
2268 });
2269
2270 let client =
2271 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2272 cx.update(|cx| {
2273 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2274 });
2275
2276 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2277
2278 let buffer = project
2279 .update(cx, |project, cx| {
2280 let path = project
2281 .find_project_path(path!("/project/main.rs"), cx)
2282 .unwrap();
2283 project.open_buffer(path, cx)
2284 })
2285 .await
2286 .unwrap();
2287
2288 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2289 ep_store.update(cx, |ep_store, cx| {
2290 ep_store.register_buffer(&buffer, &project, cx)
2291 });
2292 cx.background_executor.run_until_parked();
2293
2294 let completion_task = ep_store.update(cx, |ep_store, cx| {
2295 ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2296 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2297 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2298 });
2299
2300 let _ = completion_task.await;
2301
2302 assert!(
2303 predict_called.load(std::sync::atomic::Ordering::SeqCst),
2304 "With custom URL, predict endpoint should be called even without authentication"
2305 );
2306}
2307
2308#[gpui::test]
2309fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2310 let buffer = cx.new(|cx| {
2311 Buffer::local(
2312 indoc! {"
2313 zero
2314 one
2315 two
2316 three
2317 four
2318 five
2319 six
2320 seven
2321 eight
2322 nine
2323 ten
2324 eleven
2325 twelve
2326 thirteen
2327 fourteen
2328 fifteen
2329 sixteen
2330 seventeen
2331 eighteen
2332 nineteen
2333 twenty
2334 twenty-one
2335 twenty-two
2336 twenty-three
2337 twenty-four
2338 "},
2339 cx,
2340 )
2341 });
2342
2343 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2344
2345 buffer.update(cx, |buffer, cx| {
2346 let point = Point::new(12, 0);
2347 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2348 let point = Point::new(8, 0);
2349 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2350 });
2351
2352 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2353
2354 let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2355
2356 assert_eq!(
2357 diff,
2358 indoc! {"
2359 @@ -6,10 +6,12 @@
2360 five
2361 six
2362 seven
2363 +FIRST INSERTION
2364 eight
2365 nine
2366 ten
2367 eleven
2368 +SECOND INSERTION
2369 twelve
2370 thirteen
2371 fourteen
2372 "}
2373 );
2374}
2375
2376#[ctor::ctor]
2377fn init_logger() {
2378 zlog::init_test();
2379}