1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
10};
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::App;
16use gpui::{
17 Entity, TestAppContext,
18 http_client::{FakeHttpClient, Response},
19};
20use indoc::indoc;
21use language::{Buffer, Point};
22use lsp::LanguageServerId;
23use parking_lot::Mutex;
24use pretty_assertions::{assert_eq, assert_matches};
25use project::{FakeFs, Project};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, sync::Arc, time::Duration};
29use util::{path, rel_path::rel_path};
30use uuid::Uuid;
31use zeta_prompt::ZetaPromptInput;
32
33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
34
/// Exercises `prediction_at` relative to the current buffer: a prediction
/// targeting the active file is surfaced as `Local`, while one targeting
/// another file is surfaced as `Jump` until that file's buffer is queried
/// directly.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and make it the active path so predictions for it count as
    // "current file" predictions.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction edits the active buffer, so it's `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // Note: no explicit refresh here — publishing the diagnostic is what
    // kicks off this next prediction request.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Viewed from buffer1, the prediction lives in a different file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Viewed from the buffer it actually edits, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
170
171#[gpui::test]
172async fn test_simple_request(cx: &mut TestAppContext) {
173 let (ep_store, mut requests) = init_test_with_fake_client(cx);
174 let fs = FakeFs::new(cx.executor());
175 fs.insert_tree(
176 "/root",
177 json!({
178 "foo.md": "Hello!\nHow\nBye\n"
179 }),
180 )
181 .await;
182 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
183
184 let buffer = project
185 .update(cx, |project, cx| {
186 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
187 project.open_buffer(path, cx)
188 })
189 .await
190 .unwrap();
191 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
192 let position = snapshot.anchor_before(language::Point::new(1, 3));
193
194 let prediction_task = ep_store.update(cx, |ep_store, cx| {
195 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
196 });
197
198 let (request, respond_tx) = requests.predict.next().await.unwrap();
199
200 // TODO Put back when we have a structured request again
201 // assert_eq!(
202 // request.excerpt_path.as_ref(),
203 // Path::new(path!("root/foo.md"))
204 // );
205 // assert_eq!(
206 // request.cursor_point,
207 // Point {
208 // line: Line(1),
209 // column: 3
210 // }
211 // );
212
213 respond_tx
214 .send(model_response(
215 &request,
216 indoc! { r"
217 --- a/root/foo.md
218 +++ b/root/foo.md
219 @@ ... @@
220 Hello!
221 -How
222 +How are you?
223 Bye
224 "},
225 ))
226 .unwrap();
227
228 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
229
230 assert_eq!(prediction.edits.len(), 1);
231 assert_eq!(
232 prediction.edits[0].0.to_point(&snapshot).start,
233 language::Point::new(1, 3)
234 );
235 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
236}
237
/// Edits made to a registered buffer are recorded as events and included in
/// the prediction prompt as a unified diff of the recent change.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" on the empty second line; this edit should appear as an
    // event diff in the next request's prompt.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
310
/// Edit-history getters: rapid edits coalesce into a single event, but the
/// pause-splitting getter divides the last event at a pause longer than
/// `LAST_CHANGE_GROUPING_TIME`, yielding two separate diffs.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: everything before the pause.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second event: the post-pause burst, diffed against the paused state.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
405
/// Line-span coalescing of edit events: edits within 8 lines of an existing
/// event's range (above or below) merge into that event; edits farther away
/// start a new event.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the coalesced one, plus a separate one for the distant edit.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
584
585fn render_events(events: &[StoredEvent]) -> String {
586 events
587 .iter()
588 .map(|e| {
589 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
590 diff.as_str()
591 })
592 .collect::<Vec<_>>()
593 .join("\n---\n")
594}
595
/// An empty model response yields no prediction, and the store reports it
/// back as rejected with reason `Empty` (and `was_shown: false`).
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
650
/// If the user types the predicted text before the response arrives,
/// interpolating the prediction against the new buffer contents leaves no
/// edits, so the prediction is dropped and rejected as `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the predicted text manually while the request is still in flight.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
710
/// Model response shared by several tests below: a udiff completing "How"
/// on line 2 of `root/foo.md` into "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
720
/// A newer prediction replaces the current one; the replaced prediction is
/// reported as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
802
/// When a newly arrived prediction is worse than the current one (here, a
/// strict prefix of it), the current prediction is kept and the new one is
/// rejected as `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
896
/// When two requests are in flight and the later one resolves first, the
/// earlier request is cancelled: its late response is ignored and it is
/// reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the (already-cancelled) first request respond late.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
987
/// With two requests already pending, starting a third cancels the second
/// (index 1 in `cancelled_predictions`). The first may still complete and
/// become current, the cancelled second is ignored when it responds, and
/// the third ultimately replaces the first.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request responds late; its result must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1123
/// Rejection reporting: rejections are debounced and batched, flushed
/// immediately (in half-batch chunks) when the batch-size limit is reached,
/// and retried together with newer items after a failed request.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    // NOTE(review): assumes 15s exceeds the debounce interval — confirm
    // against REJECT_REQUEST_DEBOUNCE.
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1235
1236// Skipped until we start including diagnostics in prompt
1237// #[gpui::test]
1238// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1239// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1240// let fs = FakeFs::new(cx.executor());
1241// fs.insert_tree(
1242// "/root",
1243// json!({
1244// "foo.md": "Hello!\nBye"
1245// }),
1246// )
1247// .await;
1248// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1249
1250// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1251// let diagnostic = lsp::Diagnostic {
1252// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1253// severity: Some(lsp::DiagnosticSeverity::ERROR),
1254// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1255// ..Default::default()
1256// };
1257
1258// project.update(cx, |project, cx| {
1259// project.lsp_store().update(cx, |lsp_store, cx| {
1260// // Create some diagnostics
1261// lsp_store
1262// .update_diagnostics(
1263// LanguageServerId(0),
1264// lsp::PublishDiagnosticsParams {
1265// uri: path_to_buffer_uri.clone(),
1266// diagnostics: vec![diagnostic],
1267// version: None,
1268// },
1269// None,
1270// language::DiagnosticSourceKind::Pushed,
1271// &[],
1272// cx,
1273// )
1274// .unwrap();
1275// });
1276// });
1277
1278// let buffer = project
1279// .update(cx, |project, cx| {
1280// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1281// project.open_buffer(path, cx)
1282// })
1283// .await
1284// .unwrap();
1285
1286// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1287// let position = snapshot.anchor_before(language::Point::new(0, 0));
1288
1289// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1290// ep_store.request_prediction(&project, &buffer, position, cx)
1291// });
1292
1293// let (request, _respond_tx) = req_rx.next().await.unwrap();
1294
1295// assert_eq!(request.diagnostic_groups.len(), 1);
1296// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1297// .unwrap();
1298// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1299// assert_eq!(
1300// value,
1301// json!({
1302// "entries": [{
1303// "range": {
1304// "start": 8,
1305// "end": 10
1306// },
1307// "diagnostic": {
1308// "source": null,
1309// "code": null,
1310// "code_description": null,
1311// "severity": 1,
1312// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1313// "markdown": null,
1314// "group_id": 0,
1315// "is_primary": true,
1316// "is_disk_based": false,
1317// "is_unnecessary": false,
1318// "source_kind": "Pushed",
1319// "data": null,
1320// "underline": true
1321// }
1322// }],
1323// "primary_ix": 0
1324// })
1325// );
1326// }
1327
1328// Generate a model response that would apply the given diff to the active file.
1329fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1330 let excerpt =
1331 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1332 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1333
1334 PredictEditsV3Response {
1335 request_id: Uuid::new_v4().to_string(),
1336 output: new_excerpt,
1337 }
1338}
1339
/// Renders the full prompt string that would be sent to the model for this
/// request, using the request's own prompt version.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
}
1343
/// Receiver halves of the channels through which the fake HTTP client in
/// `init_test_with_fake_client` hands intercepted requests to the test.
struct RequestChannels {
    // Each item is a parsed predict-edits request paired with a oneshot
    // sender the test uses to supply the fake response.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Same scheme for rejection batches; sending `()` completes the fake
    // HTTP request successfully, dropping the sender simulates a failure.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1351
/// Builds an `EditPredictionStore` backed by a fake HTTP client that routes
/// predict/reject requests onto channels, letting the test inspect each
/// request body and choose its response.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: always succeed with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Parse the request, forward it to the test, then
                        // block until the test sends back the response.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Same forwarding scheme for rejection batches; a
                        // dropped sender propagates as a request error.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // Pretend we are signed in so requests carry credentials.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1420
// Example permissive (0BSD) license text; placing it in a worktree's LICENSE
// file marks that worktree as open source in the data-collection tests below.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1422
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    // A prediction over "Lorem ipsum dolor" with two edits:
    //   2..5 -> "REM"  (replace "rem")
    //   9..11 -> ""    (delete "um")
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Construct a prediction by hand; only `edits`/`snapshot` matter for
    // interpolation, the rest are placeholder values.
    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the target range turns the replacement into an insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted text shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the full replacement consumes the first edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of the typed text brings the edit back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion manually consumes that edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1536
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Model output contains region markers and a stray trailing blank line;
    // the applied result should be the clean edited file.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Same, with an edit that extends a string literal.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1594
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    // The model appends a line at the end of the buffer; the applied result
    // should contain the new line without duplicating the final newline.
    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
        ```animals.js
        <|start_of_file|>
        <|editable_region_start|>
        lorem
        ipsum
        <|editable_region_end|>
        ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1614
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor at the end of "hello".
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve anchors to offsets so we can assert the exact edit.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
1679
1680#[gpui::test]
1681async fn test_can_collect_data(cx: &mut TestAppContext) {
1682 init_test(cx);
1683
1684 let fs = project::FakeFs::new(cx.executor());
1685 fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1686 .await;
1687
1688 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1689 let buffer = project
1690 .update(cx, |project, cx| {
1691 project.open_local_buffer(path!("/project/src/main.rs"), cx)
1692 })
1693 .await
1694 .unwrap();
1695
1696 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1697 ep_store.update(cx, |ep_store, _cx| {
1698 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1699 });
1700
1701 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1702 assert_eq!(
1703 captured_request.lock().clone().unwrap().can_collect_data,
1704 true
1705 );
1706
1707 ep_store.update(cx, |ep_store, _cx| {
1708 ep_store.data_collection_choice = DataCollectionChoice::Disabled
1709 });
1710
1711 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1712 assert_eq!(
1713 captured_request.lock().clone().unwrap().can_collect_data,
1714 false
1715 );
1716}
1717
1718#[gpui::test]
1719async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1720 init_test(cx);
1721
1722 let fs = project::FakeFs::new(cx.executor());
1723 let project = Project::test(fs.clone(), [], cx).await;
1724
1725 let buffer = cx.new(|_cx| {
1726 Buffer::remote(
1727 language::BufferId::new(1).unwrap(),
1728 ReplicaId::new(1),
1729 language::Capability::ReadWrite,
1730 "fn main() {\n println!(\"Hello\");\n}",
1731 )
1732 });
1733
1734 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1735 ep_store.update(cx, |ep_store, _cx| {
1736 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1737 });
1738
1739 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1740 assert_eq!(
1741 captured_request.lock().clone().unwrap().can_collect_data,
1742 false
1743 );
1744}
1745
1746#[gpui::test]
1747async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1748 init_test(cx);
1749
1750 let fs = project::FakeFs::new(cx.executor());
1751 fs.insert_tree(
1752 path!("/project"),
1753 json!({
1754 "LICENSE": BSD_0_TXT,
1755 ".env": "SECRET_KEY=secret"
1756 }),
1757 )
1758 .await;
1759
1760 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1761 let buffer = project
1762 .update(cx, |project, cx| {
1763 project.open_local_buffer("/project/.env", cx)
1764 })
1765 .await
1766 .unwrap();
1767
1768 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1769 ep_store.update(cx, |ep_store, _cx| {
1770 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1771 });
1772
1773 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1774 assert_eq!(
1775 captured_request.lock().clone().unwrap().can_collect_data,
1776 false
1777 );
1778}
1779
1780#[gpui::test]
1781async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1782 init_test(cx);
1783
1784 let fs = project::FakeFs::new(cx.executor());
1785 let project = Project::test(fs.clone(), [], cx).await;
1786 let buffer = cx.new(|cx| Buffer::local("", cx));
1787
1788 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1789 ep_store.update(cx, |ep_store, _cx| {
1790 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1791 });
1792
1793 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1794 assert_eq!(
1795 captured_request.lock().clone().unwrap().can_collect_data,
1796 false
1797 );
1798}
1799
1800#[gpui::test]
1801async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1802 init_test(cx);
1803
1804 let fs = project::FakeFs::new(cx.executor());
1805 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1806 .await;
1807
1808 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1809 let buffer = project
1810 .update(cx, |project, cx| {
1811 project.open_local_buffer("/project/main.rs", cx)
1812 })
1813 .await
1814 .unwrap();
1815
1816 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1817 ep_store.update(cx, |ep_store, _cx| {
1818 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1819 });
1820
1821 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1822 assert_eq!(
1823 captured_request.lock().clone().unwrap().can_collect_data,
1824 false
1825 );
1826}
1827
1828#[gpui::test]
1829async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1830 init_test(cx);
1831
1832 let fs = project::FakeFs::new(cx.executor());
1833 fs.insert_tree(
1834 path!("/open_source_worktree"),
1835 json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1836 )
1837 .await;
1838 fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1839 .await;
1840
1841 let project = Project::test(
1842 fs.clone(),
1843 [
1844 path!("/open_source_worktree").as_ref(),
1845 path!("/closed_source_worktree").as_ref(),
1846 ],
1847 cx,
1848 )
1849 .await;
1850 let buffer = project
1851 .update(cx, |project, cx| {
1852 project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1853 })
1854 .await
1855 .unwrap();
1856
1857 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1858 ep_store.update(cx, |ep_store, _cx| {
1859 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1860 });
1861
1862 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1863 assert_eq!(
1864 captured_request.lock().clone().unwrap().can_collect_data,
1865 true
1866 );
1867
1868 let closed_source_file = project
1869 .update(cx, |project, cx| {
1870 let worktree2 = project
1871 .worktree_for_root_name("closed_source_worktree", cx)
1872 .unwrap();
1873 worktree2.update(cx, |worktree2, cx| {
1874 worktree2.load_file(rel_path("main.rs"), cx)
1875 })
1876 })
1877 .await
1878 .unwrap()
1879 .file;
1880
1881 buffer.update(cx, |buffer, cx| {
1882 buffer.file_updated(closed_source_file, cx);
1883 });
1884
1885 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1886 assert_eq!(
1887 captured_request.lock().clone().unwrap().can_collect_data,
1888 false
1889 );
1890}
1891
1892#[gpui::test]
1893async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1894 init_test(cx);
1895
1896 let fs = project::FakeFs::new(cx.executor());
1897 fs.insert_tree(
1898 path!("/worktree1"),
1899 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1900 )
1901 .await;
1902 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1903 .await;
1904
1905 let project = Project::test(
1906 fs.clone(),
1907 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1908 cx,
1909 )
1910 .await;
1911 let buffer = project
1912 .update(cx, |project, cx| {
1913 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1914 })
1915 .await
1916 .unwrap();
1917 let private_buffer = project
1918 .update(cx, |project, cx| {
1919 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1920 })
1921 .await
1922 .unwrap();
1923
1924 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1925 ep_store.update(cx, |ep_store, _cx| {
1926 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1927 });
1928
1929 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1930 assert_eq!(
1931 captured_request.lock().clone().unwrap().can_collect_data,
1932 true
1933 );
1934
1935 // this has a side effect of registering the buffer to watch for edits
1936 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1937 assert_eq!(
1938 captured_request.lock().clone().unwrap().can_collect_data,
1939 false
1940 );
1941
1942 private_buffer.update(cx, |private_buffer, cx| {
1943 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1944 });
1945
1946 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1947 assert_eq!(
1948 captured_request.lock().clone().unwrap().can_collect_data,
1949 false
1950 );
1951
1952 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1953 // included
1954 buffer.update(cx, |buffer, cx| {
1955 buffer.edit(
1956 [(
1957 0..0,
1958 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1959 )],
1960 None,
1961 cx,
1962 );
1963 });
1964
1965 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1966 assert_eq!(
1967 captured_request.lock().clone().unwrap().can_collect_data,
1968 true
1969 );
1970}
1971
/// Minimal test setup: installs a default `SettingsStore` global, which the
/// store and project APIs require before use.
fn init_test(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
    });
}
1978
1979async fn apply_edit_prediction(
1980 buffer_content: &str,
1981 completion_response: &str,
1982 cx: &mut TestAppContext,
1983) -> String {
1984 let fs = project::FakeFs::new(cx.executor());
1985 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1986 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1987 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1988 *response.lock() = completion_response.to_string();
1989 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1990 buffer.update(cx, |buffer, cx| {
1991 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1992 });
1993 buffer.read_with(cx, |buffer, _| buffer.text())
1994}
1995
/// Requests a single edit prediction for `buffer` at line 1, column 0 and
/// waits for it to resolve.
///
/// Registers the buffer with the store first (so its edits are tracked),
/// parks until registration side effects settle, then issues the request.
/// Panics if the request fails or yields no prediction.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
2012
/// Builds an `EditPredictionStore` (Zeta1 model) wired to a fake HTTP server.
///
/// Returns the store plus two shared cells: the most recent captured
/// `/predict_edits/v2` request body, and the response text the fake server
/// will return (tests overwrite it to control the model output).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    // Default canned model output used unless a test replaces it.
    let default_response = indoc! {"
        ```main.rs
        <|start_of_file|>
        <|editable_region_start|>
        hello world
        <|editable_region_end|>
        ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        // Monotonic id so each fake response is distinguishable.
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh always succeeds with a fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Capture the request body for assertions and reply with
                    // the currently-configured completion text.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install license watchers so data-collection checks see each
        // worktree's license status.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2103
2104fn to_completion_edits(
2105 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2106 buffer: &Entity<Buffer>,
2107 cx: &App,
2108) -> Vec<(Range<Anchor>, Arc<str>)> {
2109 let buffer = buffer.read(cx);
2110 iterator
2111 .into_iter()
2112 .map(|(range, text)| {
2113 (
2114 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2115 text,
2116 )
2117 })
2118 .collect()
2119}
2120
2121fn from_completion_edits(
2122 editor_edits: &[(Range<Anchor>, Arc<str>)],
2123 buffer: &Entity<Buffer>,
2124 cx: &App,
2125) -> Vec<(Range<usize>, Arc<str>)> {
2126 let buffer = buffer.read(cx);
2127 editor_edits
2128 .iter()
2129 .map(|(range, text)| {
2130 (
2131 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2132 text.clone(),
2133 )
2134 })
2135 .collect()
2136}
2137
2138#[gpui::test]
2139async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2140 init_test(cx);
2141
2142 let fs = FakeFs::new(cx.executor());
2143 fs.insert_tree(
2144 "/project",
2145 serde_json::json!({
2146 "main.rs": "fn main() {\n \n}\n"
2147 }),
2148 )
2149 .await;
2150
2151 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2152
2153 let http_client = FakeHttpClient::create(|_req| async move {
2154 Ok(gpui::http_client::Response::builder()
2155 .status(401)
2156 .body("Unauthorized".into())
2157 .unwrap())
2158 });
2159
2160 let client =
2161 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2162 cx.update(|cx| {
2163 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2164 });
2165
2166 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2167
2168 let buffer = project
2169 .update(cx, |project, cx| {
2170 let path = project
2171 .find_project_path(path!("/project/main.rs"), cx)
2172 .unwrap();
2173 project.open_buffer(path, cx)
2174 })
2175 .await
2176 .unwrap();
2177
2178 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2179 ep_store.update(cx, |ep_store, cx| {
2180 ep_store.register_buffer(&buffer, &project, cx)
2181 });
2182 cx.background_executor.run_until_parked();
2183
2184 let completion_task = ep_store.update(cx, |ep_store, cx| {
2185 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2186 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2187 });
2188
2189 let result = completion_task.await;
2190 assert!(
2191 result.is_err(),
2192 "Without authentication and without custom URL, prediction should fail"
2193 );
2194}
2195
/// Verifies that configuring a custom predict-edits URL lets a prediction
/// request reach the predict endpoint even though the client is never
/// authenticated (every non-predict request is answered with 401 below).
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    // Minimal single-file project for the store to predict against.
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Flag flipped by the fake HTTP handler when the predict endpoint is hit;
    // this is the only observable the test asserts on.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    // Fake HTTP client: requests whose path contains "predict" get a canned
    // open_ai chat-completion response (with editable-region markers in the
    // message content); everything else — including any auth traffic — gets 401.
    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                if uri.contains("predict") {
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                                ```main.rs
                                                <|start_of_file|>
                                                <|editable_region_start|>
                                                fn main() {
                                                    println!(\"Hello, world!\");
                                                }
                                                <|editable_region_end|>
                                                ```
                                            "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    // Anything other than the predict path is unauthorized,
                    // simulating an unauthenticated client.
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    // Build a real client over the fake HTTP transport; no sign-in is performed.
    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Cursor inside the empty body of `fn main` (row 1, col 4).
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    // Point the store at the custom URL, then request a prediction.
    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // The prediction result itself is irrelevant here; we only care whether
    // the predict endpoint was reached, so the task's outcome is discarded.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2305
2306#[gpui::test]
2307fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2308 let buffer = cx.new(|cx| {
2309 Buffer::local(
2310 indoc! {"
2311 zero
2312 one
2313 two
2314 three
2315 four
2316 five
2317 six
2318 seven
2319 eight
2320 nine
2321 ten
2322 eleven
2323 twelve
2324 thirteen
2325 fourteen
2326 fifteen
2327 sixteen
2328 seventeen
2329 eighteen
2330 nineteen
2331 twenty
2332 twenty-one
2333 twenty-two
2334 twenty-three
2335 twenty-four
2336 "},
2337 cx,
2338 )
2339 });
2340
2341 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2342
2343 buffer.update(cx, |buffer, cx| {
2344 let point = Point::new(12, 0);
2345 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2346 let point = Point::new(8, 0);
2347 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2348 });
2349
2350 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2351
2352 let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2353
2354 assert_eq!(
2355 diff,
2356 indoc! {"
2357 @@ -6,10 +6,12 @@
2358 five
2359 six
2360 seven
2361 +FIRST INSERTION
2362 eight
2363 nine
2364 ten
2365 eleven
2366 +SECOND INSERTION
2367 twelve
2368 thirteen
2369 fourteen
2370 "}
2371 );
2372}
2373
/// Runs at program start (before any test) via the `ctor` constructor
/// attribute, initializing `zlog`'s test logging.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}