1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
10};
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::App;
16use gpui::{
17 Entity, TestAppContext,
18 http_client::{FakeHttpClient, Response},
19};
20use indoc::indoc;
21use language::{Buffer, Point};
22use lsp::LanguageServerId;
23use parking_lot::Mutex;
24use pretty_assertions::{assert_eq, assert_matches};
25use project::{FakeFs, Project};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, sync::Arc, time::Duration};
29use util::{path, rel_path::rel_path};
30use uuid::Uuid;
31use zeta_prompt::ZetaPromptInput;
32
33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
34
35#[gpui::test]
36async fn test_current_state(cx: &mut TestAppContext) {
37 let (ep_store, mut requests) = init_test_with_fake_client(cx);
38 let fs = FakeFs::new(cx.executor());
39 fs.insert_tree(
40 "/root",
41 json!({
42 "1.txt": "Hello!\nHow\nBye\n",
43 "2.txt": "Hola!\nComo\nAdios\n"
44 }),
45 )
46 .await;
47 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
48
49 let buffer1 = project
50 .update(cx, |project, cx| {
51 let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
52 project.set_active_path(Some(path.clone()), cx);
53 project.open_buffer(path, cx)
54 })
55 .await
56 .unwrap();
57 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
58 let position = snapshot1.anchor_before(language::Point::new(1, 3));
59
60 ep_store.update(cx, |ep_store, cx| {
61 ep_store.register_project(&project, cx);
62 ep_store.register_buffer(&buffer1, &project, cx);
63 });
64
65 // Prediction for current file
66
67 ep_store.update(cx, |ep_store, cx| {
68 ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
69 });
70 let (request, respond_tx) = requests.predict.next().await.unwrap();
71
72 respond_tx
73 .send(model_response(
74 &request,
75 indoc! {r"
76 --- a/root/1.txt
77 +++ b/root/1.txt
78 @@ ... @@
79 Hello!
80 -How
81 +How are you?
82 Bye
83 "},
84 ))
85 .unwrap();
86
87 cx.run_until_parked();
88
89 ep_store.update(cx, |ep_store, cx| {
90 let prediction = ep_store
91 .prediction_at(&buffer1, None, &project, cx)
92 .unwrap();
93 assert_matches!(prediction, BufferEditPrediction::Local { .. });
94 });
95
96 ep_store.update(cx, |ep_store, cx| {
97 ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
98 });
99
100 // Prediction for diagnostic in another file
101
102 let diagnostic = lsp::Diagnostic {
103 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
104 severity: Some(lsp::DiagnosticSeverity::ERROR),
105 message: "Sentence is incomplete".to_string(),
106 ..Default::default()
107 };
108
109 project.update(cx, |project, cx| {
110 project.lsp_store().update(cx, |lsp_store, cx| {
111 lsp_store
112 .update_diagnostics(
113 LanguageServerId(0),
114 lsp::PublishDiagnosticsParams {
115 uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
116 diagnostics: vec![diagnostic],
117 version: None,
118 },
119 None,
120 language::DiagnosticSourceKind::Pushed,
121 &[],
122 cx,
123 )
124 .unwrap();
125 });
126 });
127
128 let (request, respond_tx) = requests.predict.next().await.unwrap();
129 respond_tx
130 .send(model_response(
131 &request,
132 indoc! {r#"
133 --- a/root/2.txt
134 +++ b/root/2.txt
135 @@ ... @@
136 Hola!
137 -Como
138 +Como estas?
139 Adios
140 "#},
141 ))
142 .unwrap();
143 cx.run_until_parked();
144
145 ep_store.update(cx, |ep_store, cx| {
146 let prediction = ep_store
147 .prediction_at(&buffer1, None, &project, cx)
148 .unwrap();
149 assert_matches!(
150 prediction,
151 BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
152 );
153 });
154
155 let buffer2 = project
156 .update(cx, |project, cx| {
157 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
158 project.open_buffer(path, cx)
159 })
160 .await
161 .unwrap();
162
163 ep_store.update(cx, |ep_store, cx| {
164 let prediction = ep_store
165 .prediction_at(&buffer2, None, &project, cx)
166 .unwrap();
167 assert_matches!(prediction, BufferEditPrediction::Local { .. });
168 });
169}
170
171#[gpui::test]
172async fn test_simple_request(cx: &mut TestAppContext) {
173 let (ep_store, mut requests) = init_test_with_fake_client(cx);
174 let fs = FakeFs::new(cx.executor());
175 fs.insert_tree(
176 "/root",
177 json!({
178 "foo.md": "Hello!\nHow\nBye\n"
179 }),
180 )
181 .await;
182 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
183
184 let buffer = project
185 .update(cx, |project, cx| {
186 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
187 project.open_buffer(path, cx)
188 })
189 .await
190 .unwrap();
191 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
192 let position = snapshot.anchor_before(language::Point::new(1, 3));
193
194 let prediction_task = ep_store.update(cx, |ep_store, cx| {
195 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
196 });
197
198 let (request, respond_tx) = requests.predict.next().await.unwrap();
199
200 // TODO Put back when we have a structured request again
201 // assert_eq!(
202 // request.excerpt_path.as_ref(),
203 // Path::new(path!("root/foo.md"))
204 // );
205 // assert_eq!(
206 // request.cursor_point,
207 // Point {
208 // line: Line(1),
209 // column: 3
210 // }
211 // );
212
213 respond_tx
214 .send(model_response(
215 &request,
216 indoc! { r"
217 --- a/root/foo.md
218 +++ b/root/foo.md
219 @@ ... @@
220 Hello!
221 -How
222 +How are you?
223 Bye
224 "},
225 ))
226 .unwrap();
227
228 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
229
230 assert_eq!(prediction.edits.len(), 1);
231 assert_eq!(
232 prediction.edits[0].0.to_point(&snapshot).start,
233 language::Point::new(1, 3)
234 );
235 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
236}
237
238#[gpui::test]
239async fn test_request_events(cx: &mut TestAppContext) {
240 let (ep_store, mut requests) = init_test_with_fake_client(cx);
241 let fs = FakeFs::new(cx.executor());
242 fs.insert_tree(
243 "/root",
244 json!({
245 "foo.md": "Hello!\n\nBye\n"
246 }),
247 )
248 .await;
249 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
250
251 let buffer = project
252 .update(cx, |project, cx| {
253 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
254 project.open_buffer(path, cx)
255 })
256 .await
257 .unwrap();
258
259 ep_store.update(cx, |ep_store, cx| {
260 ep_store.register_buffer(&buffer, &project, cx);
261 });
262
263 buffer.update(cx, |buffer, cx| {
264 buffer.edit(vec![(7..7, "How")], None, cx);
265 });
266
267 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
268 let position = snapshot.anchor_before(language::Point::new(1, 3));
269
270 let prediction_task = ep_store.update(cx, |ep_store, cx| {
271 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
272 });
273
274 let (request, respond_tx) = requests.predict.next().await.unwrap();
275
276 let prompt = prompt_from_request(&request);
277 assert!(
278 prompt.contains(indoc! {"
279 --- a/root/foo.md
280 +++ b/root/foo.md
281 @@ -1,3 +1,3 @@
282 Hello!
283 -
284 +How
285 Bye
286 "}),
287 "{prompt}"
288 );
289
290 respond_tx
291 .send(model_response(
292 &request,
293 indoc! {r#"
294 --- a/root/foo.md
295 +++ b/root/foo.md
296 @@ ... @@
297 Hello!
298 -How
299 +How are you?
300 Bye
301 "#},
302 ))
303 .unwrap();
304
305 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
306
307 assert_eq!(prediction.edits.len(), 1);
308 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
309}
310
311#[gpui::test]
312async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
313 let (ep_store, _requests) = init_test_with_fake_client(cx);
314 let fs = FakeFs::new(cx.executor());
315 fs.insert_tree(
316 "/root",
317 json!({
318 "foo.md": "Hello!\n\nBye\n"
319 }),
320 )
321 .await;
322 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
323
324 let buffer = project
325 .update(cx, |project, cx| {
326 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
327 project.open_buffer(path, cx)
328 })
329 .await
330 .unwrap();
331
332 ep_store.update(cx, |ep_store, cx| {
333 ep_store.register_buffer(&buffer, &project, cx);
334 });
335
336 // First burst: insert "How"
337 buffer.update(cx, |buffer, cx| {
338 buffer.edit(vec![(7..7, "How")], None, cx);
339 });
340
341 // Simulate a pause longer than the grouping threshold (e.g. 500ms).
342 cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
343 cx.run_until_parked();
344
345 // Second burst: append " are you?" immediately after "How" on the same line.
346 //
347 // Keeping both bursts on the same line ensures the existing line-span coalescing logic
348 // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
349 buffer.update(cx, |buffer, cx| {
350 buffer.edit(vec![(10..10, " are you?")], None, cx);
351 });
352
353 // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
354 // advanced after the pause boundary is recorded, making pause-splitting deterministic.
355 buffer.update(cx, |buffer, cx| {
356 buffer.edit(vec![(19..19, "!")], None, cx);
357 });
358
359 // Without time-based splitting, there is one event.
360 let events = ep_store.update(cx, |ep_store, cx| {
361 ep_store.edit_history_for_project(&project, cx)
362 });
363 assert_eq!(events.len(), 1);
364 let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
365 assert_eq!(
366 diff.as_str(),
367 indoc! {"
368 @@ -1,3 +1,3 @@
369 Hello!
370 -
371 +How are you?!
372 Bye
373 "}
374 );
375
376 // With time-based splitting, there are two distinct events.
377 let events = ep_store.update(cx, |ep_store, cx| {
378 ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
379 });
380 assert_eq!(events.len(), 2);
381 let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
382 assert_eq!(
383 diff.as_str(),
384 indoc! {"
385 @@ -1,3 +1,3 @@
386 Hello!
387 -
388 +How
389 Bye
390 "}
391 );
392
393 let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
394 assert_eq!(
395 diff.as_str(),
396 indoc! {"
397 @@ -1,3 +1,3 @@
398 Hello!
399 -How
400 +How are you?!
401 Bye
402 "}
403 );
404}
405
406#[gpui::test]
407async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
408 let (ep_store, _requests) = init_test_with_fake_client(cx);
409 let fs = FakeFs::new(cx.executor());
410
411 // Create a file with 30 lines to test line-based coalescing
412 let content = (1..=30)
413 .map(|i| format!("Line {}\n", i))
414 .collect::<String>();
415 fs.insert_tree(
416 "/root",
417 json!({
418 "foo.md": content
419 }),
420 )
421 .await;
422 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
423
424 let buffer = project
425 .update(cx, |project, cx| {
426 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
427 project.open_buffer(path, cx)
428 })
429 .await
430 .unwrap();
431
432 ep_store.update(cx, |ep_store, cx| {
433 ep_store.register_buffer(&buffer, &project, cx);
434 });
435
436 // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
437 buffer.update(cx, |buffer, cx| {
438 let start = Point::new(10, 0).to_offset(buffer);
439 let end = Point::new(13, 0).to_offset(buffer);
440 buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
441 });
442
443 let events = ep_store.update(cx, |ep_store, cx| {
444 ep_store.edit_history_for_project(&project, cx)
445 });
446 assert_eq!(
447 render_events(&events),
448 indoc! {"
449 @@ -8,9 +8,8 @@
450 Line 8
451 Line 9
452 Line 10
453 -Line 11
454 -Line 12
455 -Line 13
456 +Middle A
457 +Middle B
458 Line 14
459 Line 15
460 Line 16
461 "},
462 "After first edit"
463 );
464
465 // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
466 // This tests that coalescing considers the START of the existing range
467 buffer.update(cx, |buffer, cx| {
468 let offset = Point::new(5, 0).to_offset(buffer);
469 buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
470 });
471
472 let events = ep_store.update(cx, |ep_store, cx| {
473 ep_store.edit_history_for_project(&project, cx)
474 });
475 assert_eq!(
476 render_events(&events),
477 indoc! {"
478 @@ -3,14 +3,14 @@
479 Line 3
480 Line 4
481 Line 5
482 +Above
483 Line 6
484 Line 7
485 Line 8
486 Line 9
487 Line 10
488 -Line 11
489 -Line 12
490 -Line 13
491 +Middle A
492 +Middle B
493 Line 14
494 Line 15
495 Line 16
496 "},
497 "After inserting above (should coalesce)"
498 );
499
500 // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
501 // This tests that coalescing considers the END of the existing range
502 buffer.update(cx, |buffer, cx| {
503 let offset = Point::new(14, 0).to_offset(buffer);
504 buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
505 });
506
507 let events = ep_store.update(cx, |ep_store, cx| {
508 ep_store.edit_history_for_project(&project, cx)
509 });
510 assert_eq!(
511 render_events(&events),
512 indoc! {"
513 @@ -3,15 +3,16 @@
514 Line 3
515 Line 4
516 Line 5
517 +Above
518 Line 6
519 Line 7
520 Line 8
521 Line 9
522 Line 10
523 -Line 11
524 -Line 12
525 -Line 13
526 +Middle A
527 +Middle B
528 Line 14
529 +Below
530 Line 15
531 Line 16
532 Line 17
533 "},
534 "After inserting below (should coalesce)"
535 );
536
537 // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
538 // This should NOT coalesce - creates a new event
539 buffer.update(cx, |buffer, cx| {
540 let offset = Point::new(25, 0).to_offset(buffer);
541 buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
542 });
543
544 let events = ep_store.update(cx, |ep_store, cx| {
545 ep_store.edit_history_for_project(&project, cx)
546 });
547 assert_eq!(
548 render_events(&events),
549 indoc! {"
550 @@ -3,15 +3,16 @@
551 Line 3
552 Line 4
553 Line 5
554 +Above
555 Line 6
556 Line 7
557 Line 8
558 Line 9
559 Line 10
560 -Line 11
561 -Line 12
562 -Line 13
563 +Middle A
564 +Middle B
565 Line 14
566 +Below
567 Line 15
568 Line 16
569 Line 17
570
571 ---
572 @@ -23,6 +23,7 @@
573 Line 22
574 Line 23
575 Line 24
576 +Far below
577 Line 25
578 Line 26
579 Line 27
580 "},
581 "After inserting far below (should NOT coalesce)"
582 );
583}
584
585fn render_events(events: &[StoredEvent]) -> String {
586 events
587 .iter()
588 .map(|e| {
589 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
590 diff.as_str()
591 })
592 .collect::<Vec<_>>()
593 .join("\n---\n")
594}
595
596#[gpui::test]
597async fn test_empty_prediction(cx: &mut TestAppContext) {
598 let (ep_store, mut requests) = init_test_with_fake_client(cx);
599 let fs = FakeFs::new(cx.executor());
600 fs.insert_tree(
601 "/root",
602 json!({
603 "foo.md": "Hello!\nHow\nBye\n"
604 }),
605 )
606 .await;
607 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
608
609 let buffer = project
610 .update(cx, |project, cx| {
611 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
612 project.open_buffer(path, cx)
613 })
614 .await
615 .unwrap();
616 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
617 let position = snapshot.anchor_before(language::Point::new(1, 3));
618
619 ep_store.update(cx, |ep_store, cx| {
620 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
621 });
622
623 let (request, respond_tx) = requests.predict.next().await.unwrap();
624 let response = model_response(&request, "");
625 let id = response.request_id.clone();
626 respond_tx.send(response).unwrap();
627
628 cx.run_until_parked();
629
630 ep_store.update(cx, |ep_store, cx| {
631 assert!(
632 ep_store
633 .prediction_at(&buffer, None, &project, cx)
634 .is_none()
635 );
636 });
637
638 // prediction is reported as rejected
639 let (reject_request, _) = requests.reject.next().await.unwrap();
640
641 assert_eq!(
642 &reject_request.rejections,
643 &[EditPredictionRejection {
644 request_id: id,
645 reason: EditPredictionRejectReason::Empty,
646 was_shown: false
647 }]
648 );
649}
650
651#[gpui::test]
652async fn test_interpolated_empty(cx: &mut TestAppContext) {
653 let (ep_store, mut requests) = init_test_with_fake_client(cx);
654 let fs = FakeFs::new(cx.executor());
655 fs.insert_tree(
656 "/root",
657 json!({
658 "foo.md": "Hello!\nHow\nBye\n"
659 }),
660 )
661 .await;
662 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
663
664 let buffer = project
665 .update(cx, |project, cx| {
666 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
667 project.open_buffer(path, cx)
668 })
669 .await
670 .unwrap();
671 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
672 let position = snapshot.anchor_before(language::Point::new(1, 3));
673
674 ep_store.update(cx, |ep_store, cx| {
675 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
676 });
677
678 let (request, respond_tx) = requests.predict.next().await.unwrap();
679
680 buffer.update(cx, |buffer, cx| {
681 buffer.set_text("Hello!\nHow are you?\nBye", cx);
682 });
683
684 let response = model_response(&request, SIMPLE_DIFF);
685 let id = response.request_id.clone();
686 respond_tx.send(response).unwrap();
687
688 cx.run_until_parked();
689
690 ep_store.update(cx, |ep_store, cx| {
691 assert!(
692 ep_store
693 .prediction_at(&buffer, None, &project, cx)
694 .is_none()
695 );
696 });
697
698 // prediction is reported as rejected
699 let (reject_request, _) = requests.reject.next().await.unwrap();
700
701 assert_eq!(
702 &reject_request.rejections,
703 &[EditPredictionRejection {
704 request_id: id,
705 reason: EditPredictionRejectReason::InterpolatedEmpty,
706 was_shown: false
707 }]
708 );
709}
710
/// Canned model output shared by the tests in this file: a unified diff for
/// `root/foo.md` that replaces the line `How` with `How are you?`, using
/// `Hello!` and `Bye` as context lines.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
720
721#[gpui::test]
722async fn test_replace_current(cx: &mut TestAppContext) {
723 let (ep_store, mut requests) = init_test_with_fake_client(cx);
724 let fs = FakeFs::new(cx.executor());
725 fs.insert_tree(
726 "/root",
727 json!({
728 "foo.md": "Hello!\nHow\nBye\n"
729 }),
730 )
731 .await;
732 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
733
734 let buffer = project
735 .update(cx, |project, cx| {
736 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
737 project.open_buffer(path, cx)
738 })
739 .await
740 .unwrap();
741 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
742 let position = snapshot.anchor_before(language::Point::new(1, 3));
743
744 ep_store.update(cx, |ep_store, cx| {
745 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
746 });
747
748 let (request, respond_tx) = requests.predict.next().await.unwrap();
749 let first_response = model_response(&request, SIMPLE_DIFF);
750 let first_id = first_response.request_id.clone();
751 respond_tx.send(first_response).unwrap();
752
753 cx.run_until_parked();
754
755 ep_store.update(cx, |ep_store, cx| {
756 assert_eq!(
757 ep_store
758 .prediction_at(&buffer, None, &project, cx)
759 .unwrap()
760 .id
761 .0,
762 first_id
763 );
764 });
765
766 // a second request is triggered
767 ep_store.update(cx, |ep_store, cx| {
768 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
769 });
770
771 let (request, respond_tx) = requests.predict.next().await.unwrap();
772 let second_response = model_response(&request, SIMPLE_DIFF);
773 let second_id = second_response.request_id.clone();
774 respond_tx.send(second_response).unwrap();
775
776 cx.run_until_parked();
777
778 ep_store.update(cx, |ep_store, cx| {
779 // second replaces first
780 assert_eq!(
781 ep_store
782 .prediction_at(&buffer, None, &project, cx)
783 .unwrap()
784 .id
785 .0,
786 second_id
787 );
788 });
789
790 // first is reported as replaced
791 let (reject_request, _) = requests.reject.next().await.unwrap();
792
793 assert_eq!(
794 &reject_request.rejections,
795 &[EditPredictionRejection {
796 request_id: first_id,
797 reason: EditPredictionRejectReason::Replaced,
798 was_shown: false
799 }]
800 );
801}
802
803#[gpui::test]
804async fn test_current_preferred(cx: &mut TestAppContext) {
805 let (ep_store, mut requests) = init_test_with_fake_client(cx);
806 let fs = FakeFs::new(cx.executor());
807 fs.insert_tree(
808 "/root",
809 json!({
810 "foo.md": "Hello!\nHow\nBye\n"
811 }),
812 )
813 .await;
814 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
815
816 let buffer = project
817 .update(cx, |project, cx| {
818 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
819 project.open_buffer(path, cx)
820 })
821 .await
822 .unwrap();
823 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
824 let position = snapshot.anchor_before(language::Point::new(1, 3));
825
826 ep_store.update(cx, |ep_store, cx| {
827 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
828 });
829
830 let (request, respond_tx) = requests.predict.next().await.unwrap();
831 let first_response = model_response(&request, SIMPLE_DIFF);
832 let first_id = first_response.request_id.clone();
833 respond_tx.send(first_response).unwrap();
834
835 cx.run_until_parked();
836
837 ep_store.update(cx, |ep_store, cx| {
838 assert_eq!(
839 ep_store
840 .prediction_at(&buffer, None, &project, cx)
841 .unwrap()
842 .id
843 .0,
844 first_id
845 );
846 });
847
848 // a second request is triggered
849 ep_store.update(cx, |ep_store, cx| {
850 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
851 });
852
853 let (request, respond_tx) = requests.predict.next().await.unwrap();
854 // worse than current prediction
855 let second_response = model_response(
856 &request,
857 indoc! { r"
858 --- a/root/foo.md
859 +++ b/root/foo.md
860 @@ ... @@
861 Hello!
862 -How
863 +How are
864 Bye
865 "},
866 );
867 let second_id = second_response.request_id.clone();
868 respond_tx.send(second_response).unwrap();
869
870 cx.run_until_parked();
871
872 ep_store.update(cx, |ep_store, cx| {
873 // first is preferred over second
874 assert_eq!(
875 ep_store
876 .prediction_at(&buffer, None, &project, cx)
877 .unwrap()
878 .id
879 .0,
880 first_id
881 );
882 });
883
884 // second is reported as rejected
885 let (reject_request, _) = requests.reject.next().await.unwrap();
886
887 assert_eq!(
888 &reject_request.rejections,
889 &[EditPredictionRejection {
890 request_id: second_id,
891 reason: EditPredictionRejectReason::CurrentPreferred,
892 was_shown: false
893 }]
894 );
895}
896
897#[gpui::test]
898async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
899 let (ep_store, mut requests) = init_test_with_fake_client(cx);
900 let fs = FakeFs::new(cx.executor());
901 fs.insert_tree(
902 "/root",
903 json!({
904 "foo.md": "Hello!\nHow\nBye\n"
905 }),
906 )
907 .await;
908 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
909
910 let buffer = project
911 .update(cx, |project, cx| {
912 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
913 project.open_buffer(path, cx)
914 })
915 .await
916 .unwrap();
917 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
918 let position = snapshot.anchor_before(language::Point::new(1, 3));
919
920 // start two refresh tasks
921 ep_store.update(cx, |ep_store, cx| {
922 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
923 });
924
925 let (request1, respond_first) = requests.predict.next().await.unwrap();
926
927 ep_store.update(cx, |ep_store, cx| {
928 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
929 });
930
931 let (request, respond_second) = requests.predict.next().await.unwrap();
932
933 // wait for throttle
934 cx.run_until_parked();
935
936 // second responds first
937 let second_response = model_response(&request, SIMPLE_DIFF);
938 let second_id = second_response.request_id.clone();
939 respond_second.send(second_response).unwrap();
940
941 cx.run_until_parked();
942
943 ep_store.update(cx, |ep_store, cx| {
944 // current prediction is second
945 assert_eq!(
946 ep_store
947 .prediction_at(&buffer, None, &project, cx)
948 .unwrap()
949 .id
950 .0,
951 second_id
952 );
953 });
954
955 let first_response = model_response(&request1, SIMPLE_DIFF);
956 let first_id = first_response.request_id.clone();
957 respond_first.send(first_response).unwrap();
958
959 cx.run_until_parked();
960
961 ep_store.update(cx, |ep_store, cx| {
962 // current prediction is still second, since first was cancelled
963 assert_eq!(
964 ep_store
965 .prediction_at(&buffer, None, &project, cx)
966 .unwrap()
967 .id
968 .0,
969 second_id
970 );
971 });
972
973 // first is reported as rejected
974 let (reject_request, _) = requests.reject.next().await.unwrap();
975
976 cx.run_until_parked();
977
978 assert_eq!(
979 &reject_request.rejections,
980 &[EditPredictionRejection {
981 request_id: first_id,
982 reason: EditPredictionRejectReason::Canceled,
983 was_shown: false
984 }]
985 );
986}
987
988#[gpui::test]
989async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
990 let (ep_store, mut requests) = init_test_with_fake_client(cx);
991 let fs = FakeFs::new(cx.executor());
992 fs.insert_tree(
993 "/root",
994 json!({
995 "foo.md": "Hello!\nHow\nBye\n"
996 }),
997 )
998 .await;
999 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1000
1001 let buffer = project
1002 .update(cx, |project, cx| {
1003 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1004 project.open_buffer(path, cx)
1005 })
1006 .await
1007 .unwrap();
1008 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1009 let position = snapshot.anchor_before(language::Point::new(1, 3));
1010
1011 // start two refresh tasks
1012 ep_store.update(cx, |ep_store, cx| {
1013 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1014 });
1015
1016 let (request1, respond_first) = requests.predict.next().await.unwrap();
1017
1018 ep_store.update(cx, |ep_store, cx| {
1019 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1020 });
1021
1022 let (request2, respond_second) = requests.predict.next().await.unwrap();
1023
1024 // wait for throttle, so requests are sent
1025 cx.run_until_parked();
1026
1027 ep_store.update(cx, |ep_store, cx| {
1028 // start a third request
1029 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1030
1031 // 2 are pending, so 2nd is cancelled
1032 assert_eq!(
1033 ep_store
1034 .get_or_init_project(&project, cx)
1035 .cancelled_predictions
1036 .iter()
1037 .copied()
1038 .collect::<Vec<_>>(),
1039 [1]
1040 );
1041 });
1042
1043 // wait for throttle
1044 cx.run_until_parked();
1045
1046 let (request3, respond_third) = requests.predict.next().await.unwrap();
1047
1048 let first_response = model_response(&request1, SIMPLE_DIFF);
1049 let first_id = first_response.request_id.clone();
1050 respond_first.send(first_response).unwrap();
1051
1052 cx.run_until_parked();
1053
1054 ep_store.update(cx, |ep_store, cx| {
1055 // current prediction is first
1056 assert_eq!(
1057 ep_store
1058 .prediction_at(&buffer, None, &project, cx)
1059 .unwrap()
1060 .id
1061 .0,
1062 first_id
1063 );
1064 });
1065
1066 let cancelled_response = model_response(&request2, SIMPLE_DIFF);
1067 let cancelled_id = cancelled_response.request_id.clone();
1068 respond_second.send(cancelled_response).unwrap();
1069
1070 cx.run_until_parked();
1071
1072 ep_store.update(cx, |ep_store, cx| {
1073 // current prediction is still first, since second was cancelled
1074 assert_eq!(
1075 ep_store
1076 .prediction_at(&buffer, None, &project, cx)
1077 .unwrap()
1078 .id
1079 .0,
1080 first_id
1081 );
1082 });
1083
1084 let third_response = model_response(&request3, SIMPLE_DIFF);
1085 let third_response_id = third_response.request_id.clone();
1086 respond_third.send(third_response).unwrap();
1087
1088 cx.run_until_parked();
1089
1090 ep_store.update(cx, |ep_store, cx| {
1091 // third completes and replaces first
1092 assert_eq!(
1093 ep_store
1094 .prediction_at(&buffer, None, &project, cx)
1095 .unwrap()
1096 .id
1097 .0,
1098 third_response_id
1099 );
1100 });
1101
1102 // second is reported as rejected
1103 let (reject_request, _) = requests.reject.next().await.unwrap();
1104
1105 cx.run_until_parked();
1106
1107 assert_eq!(
1108 &reject_request.rejections,
1109 &[
1110 EditPredictionRejection {
1111 request_id: cancelled_id,
1112 reason: EditPredictionRejectReason::Canceled,
1113 was_shown: false
1114 },
1115 EditPredictionRejection {
1116 request_id: first_id,
1117 reason: EditPredictionRejectReason::Replaced,
1118 was_shown: false
1119 }
1120 ]
1121 );
1122}
1123
1124#[gpui::test]
1125async fn test_rejections_flushing(cx: &mut TestAppContext) {
1126 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1127
1128 ep_store.update(cx, |ep_store, cx| {
1129 ep_store.reject_prediction(
1130 EditPredictionId("test-1".into()),
1131 EditPredictionRejectReason::Discarded,
1132 false,
1133 cx,
1134 );
1135 ep_store.reject_prediction(
1136 EditPredictionId("test-2".into()),
1137 EditPredictionRejectReason::Canceled,
1138 true,
1139 cx,
1140 );
1141 });
1142
1143 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1144 cx.run_until_parked();
1145
1146 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1147 respond_tx.send(()).unwrap();
1148
1149 // batched
1150 assert_eq!(reject_request.rejections.len(), 2);
1151 assert_eq!(
1152 reject_request.rejections[0],
1153 EditPredictionRejection {
1154 request_id: "test-1".to_string(),
1155 reason: EditPredictionRejectReason::Discarded,
1156 was_shown: false
1157 }
1158 );
1159 assert_eq!(
1160 reject_request.rejections[1],
1161 EditPredictionRejection {
1162 request_id: "test-2".to_string(),
1163 reason: EditPredictionRejectReason::Canceled,
1164 was_shown: true
1165 }
1166 );
1167
1168 // Reaching batch size limit sends without debounce
1169 ep_store.update(cx, |ep_store, cx| {
1170 for i in 0..70 {
1171 ep_store.reject_prediction(
1172 EditPredictionId(format!("batch-{}", i).into()),
1173 EditPredictionRejectReason::Discarded,
1174 false,
1175 cx,
1176 );
1177 }
1178 });
1179
1180 // First MAX/2 items are sent immediately
1181 cx.run_until_parked();
1182 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1183 respond_tx.send(()).unwrap();
1184
1185 assert_eq!(reject_request.rejections.len(), 50);
1186 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1187 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1188
1189 // Remaining items are debounced with the next batch
1190 cx.executor().advance_clock(Duration::from_secs(15));
1191 cx.run_until_parked();
1192
1193 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1194 respond_tx.send(()).unwrap();
1195
1196 assert_eq!(reject_request.rejections.len(), 20);
1197 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1198 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1199
1200 // Request failure
1201 ep_store.update(cx, |ep_store, cx| {
1202 ep_store.reject_prediction(
1203 EditPredictionId("retry-1".into()),
1204 EditPredictionRejectReason::Discarded,
1205 false,
1206 cx,
1207 );
1208 });
1209
1210 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1211 cx.run_until_parked();
1212
1213 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1214 assert_eq!(reject_request.rejections.len(), 1);
1215 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1216 // Simulate failure
1217 drop(_respond_tx);
1218
1219 // Add another rejection
1220 ep_store.update(cx, |ep_store, cx| {
1221 ep_store.reject_prediction(
1222 EditPredictionId("retry-2".into()),
1223 EditPredictionRejectReason::Discarded,
1224 false,
1225 cx,
1226 );
1227 });
1228
1229 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1230 cx.run_until_parked();
1231
1232 // Retry should include both the failed item and the new one
1233 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1234 respond_tx.send(()).unwrap();
1235
1236 assert_eq!(reject_request.rejections.len(), 2);
1237 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1238 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1239}
1240
1241// Skipped until we start including diagnostics in prompt
1242// #[gpui::test]
1243// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1244// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1245// let fs = FakeFs::new(cx.executor());
1246// fs.insert_tree(
1247// "/root",
1248// json!({
1249// "foo.md": "Hello!\nBye"
1250// }),
1251// )
1252// .await;
1253// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1254
1255// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1256// let diagnostic = lsp::Diagnostic {
1257// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1258// severity: Some(lsp::DiagnosticSeverity::ERROR),
1259// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1260// ..Default::default()
1261// };
1262
1263// project.update(cx, |project, cx| {
1264// project.lsp_store().update(cx, |lsp_store, cx| {
1265// // Create some diagnostics
1266// lsp_store
1267// .update_diagnostics(
1268// LanguageServerId(0),
1269// lsp::PublishDiagnosticsParams {
1270// uri: path_to_buffer_uri.clone(),
1271// diagnostics: vec![diagnostic],
1272// version: None,
1273// },
1274// None,
1275// language::DiagnosticSourceKind::Pushed,
1276// &[],
1277// cx,
1278// )
1279// .unwrap();
1280// });
1281// });
1282
1283// let buffer = project
1284// .update(cx, |project, cx| {
1285// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1286// project.open_buffer(path, cx)
1287// })
1288// .await
1289// .unwrap();
1290
1291// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1292// let position = snapshot.anchor_before(language::Point::new(0, 0));
1293
1294// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1295// ep_store.request_prediction(&project, &buffer, position, cx)
1296// });
1297
1298// let (request, _respond_tx) = req_rx.next().await.unwrap();
1299
1300// assert_eq!(request.diagnostic_groups.len(), 1);
1301// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1302// .unwrap();
1303// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1304// assert_eq!(
1305// value,
1306// json!({
1307// "entries": [{
1308// "range": {
1309// "start": 8,
1310// "end": 10
1311// },
1312// "diagnostic": {
1313// "source": null,
1314// "code": null,
1315// "code_description": null,
1316// "severity": 1,
1317// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1318// "markdown": null,
1319// "group_id": 0,
1320// "is_primary": true,
1321// "is_disk_based": false,
1322// "is_unnecessary": false,
1323// "source_kind": "Pushed",
1324// "data": null,
1325// "underline": true
1326// }
1327// }],
1328// "primary_ix": 0
1329// })
1330// );
1331// }
1332
1333// Generate a model response that would apply the given diff to the active file.
1334fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1335 let excerpt =
1336 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1337 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1338
1339 PredictEditsV3Response {
1340 request_id: Uuid::new_v4().to_string(),
1341 output: new_excerpt,
1342 }
1343}
1344
1345fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1346 zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
1347}
1348
1349struct RequestChannels {
1350 predict: mpsc::UnboundedReceiver<(
1351 PredictEditsV3Request,
1352 oneshot::Sender<PredictEditsV3Response>,
1353 )>,
1354 reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1355}
1356
1357fn init_test_with_fake_client(
1358 cx: &mut TestAppContext,
1359) -> (Entity<EditPredictionStore>, RequestChannels) {
1360 cx.update(move |cx| {
1361 let settings_store = SettingsStore::test(cx);
1362 cx.set_global(settings_store);
1363 zlog::init_test();
1364
1365 let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1366 let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1367
1368 let http_client = FakeHttpClient::create({
1369 move |req| {
1370 let uri = req.uri().path().to_string();
1371 let mut body = req.into_body();
1372 let predict_req_tx = predict_req_tx.clone();
1373 let reject_req_tx = reject_req_tx.clone();
1374 async move {
1375 let resp = match uri.as_str() {
1376 "/client/llm_tokens" => serde_json::to_string(&json!({
1377 "token": "test"
1378 }))
1379 .unwrap(),
1380 "/predict_edits/v3" => {
1381 let mut buf = Vec::new();
1382 body.read_to_end(&mut buf).await.ok();
1383 let req = serde_json::from_slice(&buf).unwrap();
1384
1385 let (res_tx, res_rx) = oneshot::channel();
1386 predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1387 serde_json::to_string(&res_rx.await?).unwrap()
1388 }
1389 "/predict_edits/reject" => {
1390 let mut buf = Vec::new();
1391 body.read_to_end(&mut buf).await.ok();
1392 let req = serde_json::from_slice(&buf).unwrap();
1393
1394 let (res_tx, res_rx) = oneshot::channel();
1395 reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1396 serde_json::to_string(&res_rx.await?).unwrap()
1397 }
1398 _ => {
1399 panic!("Unexpected path: {}", uri)
1400 }
1401 };
1402
1403 Ok(Response::builder().body(resp.into()).unwrap())
1404 }
1405 }
1406 });
1407
1408 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1409 client.cloud_client().set_credentials(1, "test".into());
1410
1411 language_model::init(client.clone(), cx);
1412
1413 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1414 let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1415
1416 (
1417 ep_store,
1418 RequestChannels {
1419 predict: predict_req_rx,
1420 reject: reject_req_rx,
1421 },
1422 )
1423 })
1424}
1425
1426const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1427
1428#[gpui::test]
1429async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1430 let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1431 let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1432 to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1433 });
1434
1435 let edit_preview = cx
1436 .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1437 .await;
1438
1439 let prediction = EditPrediction {
1440 edits,
1441 cursor_position: None,
1442 edit_preview,
1443 buffer: buffer.clone(),
1444 snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1445 id: EditPredictionId("the-id".into()),
1446 inputs: ZetaPromptInput {
1447 events: Default::default(),
1448 related_files: Default::default(),
1449 cursor_path: Path::new("").into(),
1450 cursor_excerpt: "".into(),
1451 editable_range_in_excerpt: 0..0,
1452 cursor_offset_in_excerpt: 0,
1453 excerpt_start_row: None,
1454 },
1455 buffer_snapshotted_at: Instant::now(),
1456 response_received_at: Instant::now(),
1457 };
1458
1459 cx.update(|cx| {
1460 assert_eq!(
1461 from_completion_edits(
1462 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1463 &buffer,
1464 cx
1465 ),
1466 vec![(2..5, "REM".into()), (9..11, "".into())]
1467 );
1468
1469 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1470 assert_eq!(
1471 from_completion_edits(
1472 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1473 &buffer,
1474 cx
1475 ),
1476 vec![(2..2, "REM".into()), (6..8, "".into())]
1477 );
1478
1479 buffer.update(cx, |buffer, cx| buffer.undo(cx));
1480 assert_eq!(
1481 from_completion_edits(
1482 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1483 &buffer,
1484 cx
1485 ),
1486 vec![(2..5, "REM".into()), (9..11, "".into())]
1487 );
1488
1489 buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1490 assert_eq!(
1491 from_completion_edits(
1492 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1493 &buffer,
1494 cx
1495 ),
1496 vec![(3..3, "EM".into()), (7..9, "".into())]
1497 );
1498
1499 buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1500 assert_eq!(
1501 from_completion_edits(
1502 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1503 &buffer,
1504 cx
1505 ),
1506 vec![(4..4, "M".into()), (8..10, "".into())]
1507 );
1508
1509 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1510 assert_eq!(
1511 from_completion_edits(
1512 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1513 &buffer,
1514 cx
1515 ),
1516 vec![(9..11, "".into())]
1517 );
1518
1519 buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1520 assert_eq!(
1521 from_completion_edits(
1522 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1523 &buffer,
1524 cx
1525 ),
1526 vec![(4..4, "M".into()), (8..10, "".into())]
1527 );
1528
1529 buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1530 assert_eq!(
1531 from_completion_edits(
1532 &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1533 &buffer,
1534 cx
1535 ),
1536 vec![(4..4, "M".into())]
1537 );
1538
1539 buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1540 assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1541 })
1542}
1543
1544#[gpui::test]
1545async fn test_clean_up_diff(cx: &mut TestAppContext) {
1546 init_test(cx);
1547
1548 assert_eq!(
1549 apply_edit_prediction(
1550 indoc! {"
1551 fn main() {
1552 let word_1 = \"lorem\";
1553 let range = word.len()..word.len();
1554 }
1555 "},
1556 indoc! {"
1557 <|editable_region_start|>
1558 fn main() {
1559 let word_1 = \"lorem\";
1560 let range = word_1.len()..word_1.len();
1561 }
1562
1563 <|editable_region_end|>
1564 "},
1565 cx,
1566 )
1567 .await,
1568 indoc! {"
1569 fn main() {
1570 let word_1 = \"lorem\";
1571 let range = word_1.len()..word_1.len();
1572 }
1573 "},
1574 );
1575
1576 assert_eq!(
1577 apply_edit_prediction(
1578 indoc! {"
1579 fn main() {
1580 let story = \"the quick\"
1581 }
1582 "},
1583 indoc! {"
1584 <|editable_region_start|>
1585 fn main() {
1586 let story = \"the quick brown fox jumps over the lazy dog\";
1587 }
1588
1589 <|editable_region_end|>
1590 "},
1591 cx,
1592 )
1593 .await,
1594 indoc! {"
1595 fn main() {
1596 let story = \"the quick brown fox jumps over the lazy dog\";
1597 }
1598 "},
1599 );
1600}
1601
1602#[gpui::test]
1603async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1604 init_test(cx);
1605
1606 let buffer_content = "lorem\n";
1607 let completion_response = indoc! {"
1608 ```animals.js
1609 <|start_of_file|>
1610 <|editable_region_start|>
1611 lorem
1612 ipsum
1613 <|editable_region_end|>
1614 ```"};
1615
1616 assert_eq!(
1617 apply_edit_prediction(buffer_content, completion_response, cx).await,
1618 "lorem\nipsum"
1619 );
1620}
1621
1622#[gpui::test]
1623async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1624 // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1625 // When the buffer ends without a trailing newline, but the model returns output
1626 // with a trailing newline, zeta2 should normalize both sides before diffing
1627 // so no spurious newline is inserted.
1628 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1629 let fs = FakeFs::new(cx.executor());
1630
1631 // Single line buffer with no trailing newline
1632 fs.insert_tree(
1633 "/root",
1634 json!({
1635 "foo.txt": "hello"
1636 }),
1637 )
1638 .await;
1639 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1640
1641 let buffer = project
1642 .update(cx, |project, cx| {
1643 let path = project
1644 .find_project_path(path!("root/foo.txt"), cx)
1645 .unwrap();
1646 project.open_buffer(path, cx)
1647 })
1648 .await
1649 .unwrap();
1650
1651 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1652 let position = snapshot.anchor_before(language::Point::new(0, 5));
1653
1654 ep_store.update(cx, |ep_store, cx| {
1655 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1656 });
1657
1658 let (_request, respond_tx) = requests.predict.next().await.unwrap();
1659
1660 // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1661 // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1662 let response = PredictEditsV3Response {
1663 request_id: Uuid::new_v4().to_string(),
1664 output: "hello world\n".to_string(),
1665 };
1666 respond_tx.send(response).unwrap();
1667
1668 cx.run_until_parked();
1669
1670 // The prediction should insert " world" without adding a newline
1671 ep_store.update(cx, |ep_store, cx| {
1672 let prediction = ep_store
1673 .prediction_at(&buffer, None, &project, cx)
1674 .expect("should have prediction");
1675 let edits: Vec<_> = prediction
1676 .edits
1677 .iter()
1678 .map(|(range, text)| {
1679 let snapshot = buffer.read(cx).snapshot();
1680 (range.to_offset(&snapshot), text.clone())
1681 })
1682 .collect();
1683 assert_eq!(edits, vec![(5..5, " world".into())]);
1684 });
1685}
1686
1687#[gpui::test]
1688async fn test_can_collect_data(cx: &mut TestAppContext) {
1689 init_test(cx);
1690
1691 let fs = project::FakeFs::new(cx.executor());
1692 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1693 .await;
1694
1695 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1696 let buffer = project
1697 .update(cx, |project, cx| {
1698 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1699 })
1700 .await
1701 .unwrap();
1702
1703 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1704 ep_store.update(cx, |ep_store, _cx| {
1705 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1706 });
1707
1708 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1709 assert_eq!(
1710 captured_request.lock().clone().unwrap().can_collect_data,
1711 true
1712 );
1713
1714 ep_store.update(cx, |ep_store, _cx| {
1715 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1716 });
1717
1718 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1719 assert_eq!(
1720 captured_request.lock().clone().unwrap().can_collect_data,
1721 false
1722 );
1723}
1724
1725#[gpui::test]
1726async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1727 init_test(cx);
1728
1729 let fs = project::FakeFs::new(cx.executor());
1730 let project = Project::test(fs.clone(), [], cx).await;
1731
1732 let buffer = cx.new(|_cx| {
1733 Buffer::remote(
1734 language::BufferId::new(1).unwrap(),
1735 ReplicaId::new(1),
1736 language::Capability::ReadWrite,
1737 "fn main() {\n println!(\"Hello\");\n}",
1738 )
1739 });
1740
1741 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1742 ep_store.update(cx, |ep_store, _cx| {
1743 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1744 });
1745
1746 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1747 assert_eq!(
1748 captured_request.lock().clone().unwrap().can_collect_data,
1749 false
1750 );
1751}
1752
1753#[gpui::test]
1754async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1755 init_test(cx);
1756
1757 let fs = project::FakeFs::new(cx.executor());
1758 fs.insert_tree(
1759 path!("/project"),
1760 json!({
1761 "LICENSE": BSD_0_TXT,
1762 ".env": "SECRET_KEY=secret"
1763 }),
1764 )
1765 .await;
1766
1767 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1768 let buffer = project
1769 .update(cx, |project, cx| {
1770 project.open_local_buffer("/project/.env", cx)
1771 })
1772 .await
1773 .unwrap();
1774
1775 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1776 ep_store.update(cx, |ep_store, _cx| {
1777 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1778 });
1779
1780 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1781 assert_eq!(
1782 captured_request.lock().clone().unwrap().can_collect_data,
1783 false
1784 );
1785}
1786
1787#[gpui::test]
1788async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1789 init_test(cx);
1790
1791 let fs = project::FakeFs::new(cx.executor());
1792 let project = Project::test(fs.clone(), [], cx).await;
1793 let buffer = cx.new(|cx| Buffer::local("", cx));
1794
1795 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1796 ep_store.update(cx, |ep_store, _cx| {
1797 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1798 });
1799
1800 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1801 assert_eq!(
1802 captured_request.lock().clone().unwrap().can_collect_data,
1803 false
1804 );
1805}
1806
1807#[gpui::test]
1808async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1809 init_test(cx);
1810
1811 let fs = project::FakeFs::new(cx.executor());
1812 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1813 .await;
1814
1815 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1816 let buffer = project
1817 .update(cx, |project, cx| {
1818 project.open_local_buffer("/project/main.rs", cx)
1819 })
1820 .await
1821 .unwrap();
1822
1823 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1824 ep_store.update(cx, |ep_store, _cx| {
1825 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1826 });
1827
1828 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1829 assert_eq!(
1830 captured_request.lock().clone().unwrap().can_collect_data,
1831 false
1832 );
1833}
1834
1835#[gpui::test]
1836async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1837 init_test(cx);
1838
1839 let fs = project::FakeFs::new(cx.executor());
1840 fs.insert_tree(
1841 path!("/open_source_worktree"),
1842 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1843 )
1844 .await;
1845 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1846 .await;
1847
1848 let project = Project::test(
1849 fs.clone(),
1850 [
1851 path!("/open_source_worktree").as_ref(),
1852 path!("/closed_source_worktree").as_ref(),
1853 ],
1854 cx,
1855 )
1856 .await;
1857 let buffer = project
1858 .update(cx, |project, cx| {
1859 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1860 })
1861 .await
1862 .unwrap();
1863
1864 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1865 ep_store.update(cx, |ep_store, _cx| {
1866 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1867 });
1868
1869 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1870 assert_eq!(
1871 captured_request.lock().clone().unwrap().can_collect_data,
1872 true
1873 );
1874
1875 let closed_source_file = project
1876 .update(cx, |project, cx| {
1877 let worktree2 = project
1878 .worktree_for_root_name("closed_source_worktree", cx)
1879 .unwrap();
1880 worktree2.update(cx, |worktree2, cx| {
1881 worktree2.load_file(rel_path("main.rs"), cx)
1882 })
1883 })
1884 .await
1885 .unwrap()
1886 .file;
1887
1888 buffer.update(cx, |buffer, cx| {
1889 buffer.file_updated(closed_source_file, cx);
1890 });
1891
1892 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1893 assert_eq!(
1894 captured_request.lock().clone().unwrap().can_collect_data,
1895 false
1896 );
1897}
1898
1899#[gpui::test]
1900async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1901 init_test(cx);
1902
1903 let fs = project::FakeFs::new(cx.executor());
1904 fs.insert_tree(
1905 path!("/worktree1"),
1906 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1907 )
1908 .await;
1909 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1910 .await;
1911
1912 let project = Project::test(
1913 fs.clone(),
1914 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1915 cx,
1916 )
1917 .await;
1918 let buffer = project
1919 .update(cx, |project, cx| {
1920 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1921 })
1922 .await
1923 .unwrap();
1924 let private_buffer = project
1925 .update(cx, |project, cx| {
1926 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1927 })
1928 .await
1929 .unwrap();
1930
1931 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1932 ep_store.update(cx, |ep_store, _cx| {
1933 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1934 });
1935
1936 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1937 assert_eq!(
1938 captured_request.lock().clone().unwrap().can_collect_data,
1939 true
1940 );
1941
1942 // this has a side effect of registering the buffer to watch for edits
1943 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1944 assert_eq!(
1945 captured_request.lock().clone().unwrap().can_collect_data,
1946 false
1947 );
1948
1949 private_buffer.update(cx, |private_buffer, cx| {
1950 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1951 });
1952
1953 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1954 assert_eq!(
1955 captured_request.lock().clone().unwrap().can_collect_data,
1956 false
1957 );
1958
1959 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1960 // included
1961 buffer.update(cx, |buffer, cx| {
1962 buffer.edit(
1963 [(
1964 0..0,
1965 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1966 )],
1967 None,
1968 cx,
1969 );
1970 });
1971
1972 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1973 assert_eq!(
1974 captured_request.lock().clone().unwrap().can_collect_data,
1975 true
1976 );
1977}
1978
1979fn init_test(cx: &mut TestAppContext) {
1980 cx.update(|cx| {
1981 let settings_store = SettingsStore::test(cx);
1982 cx.set_global(settings_store);
1983 });
1984}
1985
1986async fn apply_edit_prediction(
1987 buffer_content: &str,
1988 completion_response: &str,
1989 cx: &mut TestAppContext,
1990) -> String {
1991 let fs = project::FakeFs::new(cx.executor());
1992 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1993 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1994 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1995 *response.lock() = completion_response.to_string();
1996 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1997 buffer.update(cx, |buffer, cx| {
1998 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1999 });
2000 buffer.read_with(cx, |buffer, _| buffer.text())
2001}
2002
2003async fn run_edit_prediction(
2004 buffer: &Entity<Buffer>,
2005 project: &Entity<Project>,
2006 ep_store: &Entity<EditPredictionStore>,
2007 cx: &mut TestAppContext,
2008) -> EditPrediction {
2009 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2010 ep_store.update(cx, |ep_store, cx| {
2011 ep_store.register_buffer(buffer, &project, cx)
2012 });
2013 cx.background_executor.run_until_parked();
2014 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2015 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2016 });
2017 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2018}
2019
2020async fn make_test_ep_store(
2021 project: &Entity<Project>,
2022 cx: &mut TestAppContext,
2023) -> (
2024 Entity<EditPredictionStore>,
2025 Arc<Mutex<Option<PredictEditsBody>>>,
2026 Arc<Mutex<String>>,
2027) {
2028 let default_response = indoc! {"
2029 ```main.rs
2030 <|start_of_file|>
2031 <|editable_region_start|>
2032 hello world
2033 <|editable_region_end|>
2034 ```"
2035 };
2036 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2037 let completion_response: Arc<Mutex<String>> =
2038 Arc::new(Mutex::new(default_response.to_string()));
2039 let http_client = FakeHttpClient::create({
2040 let captured_request = captured_request.clone();
2041 let completion_response = completion_response.clone();
2042 let mut next_request_id = 0;
2043 move |req| {
2044 let captured_request = captured_request.clone();
2045 let completion_response = completion_response.clone();
2046 async move {
2047 match (req.method(), req.uri().path()) {
2048 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2049 .status(200)
2050 .body(
2051 serde_json::to_string(&CreateLlmTokenResponse {
2052 token: LlmToken("the-llm-token".to_string()),
2053 })
2054 .unwrap()
2055 .into(),
2056 )
2057 .unwrap()),
2058 (&Method::POST, "/predict_edits/v2") => {
2059 let mut request_body = String::new();
2060 req.into_body().read_to_string(&mut request_body).await?;
2061 *captured_request.lock() =
2062 Some(serde_json::from_str(&request_body).unwrap());
2063 next_request_id += 1;
2064 Ok(http_client::Response::builder()
2065 .status(200)
2066 .body(
2067 serde_json::to_string(&PredictEditsResponse {
2068 request_id: format!("request-{next_request_id}"),
2069 output_excerpt: completion_response.lock().clone(),
2070 })
2071 .unwrap()
2072 .into(),
2073 )
2074 .unwrap())
2075 }
2076 _ => Ok(http_client::Response::builder()
2077 .status(404)
2078 .body("Not Found".into())
2079 .unwrap()),
2080 }
2081 }
2082 }
2083 });
2084
2085 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2086 cx.update(|cx| {
2087 RefreshLlmTokenListener::register(client.clone(), cx);
2088 });
2089 let _server = FakeServer::for_client(42, &client, cx).await;
2090
2091 let ep_store = cx.new(|cx| {
2092 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2093 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2094
2095 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2096 for worktree in worktrees {
2097 let worktree_id = worktree.read(cx).id();
2098 ep_store
2099 .get_or_init_project(project, cx)
2100 .license_detection_watchers
2101 .entry(worktree_id)
2102 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2103 }
2104
2105 ep_store
2106 });
2107
2108 (ep_store, captured_request, completion_response)
2109}
2110
2111fn to_completion_edits(
2112 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2113 buffer: &Entity<Buffer>,
2114 cx: &App,
2115) -> Vec<(Range<Anchor>, Arc<str>)> {
2116 let buffer = buffer.read(cx);
2117 iterator
2118 .into_iter()
2119 .map(|(range, text)| {
2120 (
2121 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2122 text,
2123 )
2124 })
2125 .collect()
2126}
2127
2128fn from_completion_edits(
2129 editor_edits: &[(Range<Anchor>, Arc<str>)],
2130 buffer: &Entity<Buffer>,
2131 cx: &App,
2132) -> Vec<(Range<usize>, Arc<str>)> {
2133 let buffer = buffer.read(cx);
2134 editor_edits
2135 .iter()
2136 .map(|(range, text)| {
2137 (
2138 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2139 text.clone(),
2140 )
2141 })
2142 .collect()
2143}
2144
2145#[gpui::test]
2146async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2147 init_test(cx);
2148
2149 let fs = FakeFs::new(cx.executor());
2150 fs.insert_tree(
2151 "/project",
2152 serde_json::json!({
2153 "main.rs": "fn main() {\n \n}\n"
2154 }),
2155 )
2156 .await;
2157
2158 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2159
2160 let http_client = FakeHttpClient::create(|_req| async move {
2161 Ok(gpui::http_client::Response::builder()
2162 .status(401)
2163 .body("Unauthorized".into())
2164 .unwrap())
2165 });
2166
2167 let client =
2168 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2169 cx.update(|cx| {
2170 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2171 });
2172
2173 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2174
2175 let buffer = project
2176 .update(cx, |project, cx| {
2177 let path = project
2178 .find_project_path(path!("/project/main.rs"), cx)
2179 .unwrap();
2180 project.open_buffer(path, cx)
2181 })
2182 .await
2183 .unwrap();
2184
2185 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2186 ep_store.update(cx, |ep_store, cx| {
2187 ep_store.register_buffer(&buffer, &project, cx)
2188 });
2189 cx.background_executor.run_until_parked();
2190
2191 let completion_task = ep_store.update(cx, |ep_store, cx| {
2192 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2193 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2194 });
2195
2196 let result = completion_task.await;
2197 assert!(
2198 result.is_err(),
2199 "Without authentication and without custom URL, prediction should fail"
2200 );
2201}
2202
2203#[gpui::test]
2204async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2205 init_test(cx);
2206
2207 let fs = FakeFs::new(cx.executor());
2208 fs.insert_tree(
2209 "/project",
2210 serde_json::json!({
2211 "main.rs": "fn main() {\n \n}\n"
2212 }),
2213 )
2214 .await;
2215
2216 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2217
2218 let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2219 let predict_called_clone = predict_called.clone();
2220
2221 let http_client = FakeHttpClient::create({
2222 move |req| {
2223 let uri = req.uri().path().to_string();
2224 let predict_called = predict_called_clone.clone();
2225 async move {
2226 if uri.contains("predict") {
2227 predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2228 Ok(gpui::http_client::Response::builder()
2229 .body(
2230 serde_json::to_string(&open_ai::Response {
2231 id: "test-123".to_string(),
2232 object: "chat.completion".to_string(),
2233 created: 0,
2234 model: "test".to_string(),
2235 usage: open_ai::Usage {
2236 prompt_tokens: 0,
2237 completion_tokens: 0,
2238 total_tokens: 0,
2239 },
2240 choices: vec![open_ai::Choice {
2241 index: 0,
2242 message: open_ai::RequestMessage::Assistant {
2243 content: Some(open_ai::MessageContent::Plain(
2244 indoc! {"
2245 ```main.rs
2246 <|start_of_file|>
2247 <|editable_region_start|>
2248 fn main() {
2249 println!(\"Hello, world!\");
2250 }
2251 <|editable_region_end|>
2252 ```
2253 "}
2254 .to_string(),
2255 )),
2256 tool_calls: vec![],
2257 },
2258 finish_reason: Some("stop".to_string()),
2259 }],
2260 })
2261 .unwrap()
2262 .into(),
2263 )
2264 .unwrap())
2265 } else {
2266 Ok(gpui::http_client::Response::builder()
2267 .status(401)
2268 .body("Unauthorized".into())
2269 .unwrap())
2270 }
2271 }
2272 }
2273 });
2274
2275 let client =
2276 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2277 cx.update(|cx| {
2278 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2279 });
2280
2281 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2282
2283 let buffer = project
2284 .update(cx, |project, cx| {
2285 let path = project
2286 .find_project_path(path!("/project/main.rs"), cx)
2287 .unwrap();
2288 project.open_buffer(path, cx)
2289 })
2290 .await
2291 .unwrap();
2292
2293 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2294 ep_store.update(cx, |ep_store, cx| {
2295 ep_store.register_buffer(&buffer, &project, cx)
2296 });
2297 cx.background_executor.run_until_parked();
2298
2299 let completion_task = ep_store.update(cx, |ep_store, cx| {
2300 ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2301 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2302 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2303 });
2304
2305 let _ = completion_task.await;
2306
2307 assert!(
2308 predict_called.load(std::sync::atomic::Ordering::SeqCst),
2309 "With custom URL, predict endpoint should be called even without authentication"
2310 );
2311}
2312
2313#[gpui::test]
2314fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2315 let buffer = cx.new(|cx| {
2316 Buffer::local(
2317 indoc! {"
2318 zero
2319 one
2320 two
2321 three
2322 four
2323 five
2324 six
2325 seven
2326 eight
2327 nine
2328 ten
2329 eleven
2330 twelve
2331 thirteen
2332 fourteen
2333 fifteen
2334 sixteen
2335 seventeen
2336 eighteen
2337 nineteen
2338 twenty
2339 twenty-one
2340 twenty-two
2341 twenty-three
2342 twenty-four
2343 "},
2344 cx,
2345 )
2346 });
2347
2348 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2349
2350 buffer.update(cx, |buffer, cx| {
2351 let point = Point::new(12, 0);
2352 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2353 let point = Point::new(8, 0);
2354 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2355 });
2356
2357 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2358
2359 let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2360
2361 assert_eq!(
2362 diff,
2363 indoc! {"
2364 @@ -6,10 +6,12 @@
2365 five
2366 six
2367 seven
2368 +FIRST INSERTION
2369 eight
2370 nine
2371 ten
2372 eleven
2373 +SECOND INSERTION
2374 twelve
2375 thirteen
2376 fourteen
2377 "}
2378 );
2379}
2380
/// Initializes logging for this test binary.
///
/// Marked `#[ctor::ctor]` so it is invoked as a constructor rather than being called
/// from any individual test.
// NOTE(review): `zlog::init_test` presumably installs a test-friendly logger; confirm
// against the `zlog` crate.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}