1use super::*;
2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
3use client::{UserStore, test::FakeServer};
4use clock::{FakeSystemClock, ReplicaId};
5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
6use cloud_llm_client::{
7 EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
8 RejectEditPredictionsBody,
9 predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
10};
11use futures::{
12 AsyncReadExt, StreamExt,
13 channel::{mpsc, oneshot},
14};
15use gpui::{
16 Entity, TestAppContext,
17 http_client::{FakeHttpClient, Response},
18};
19use indoc::indoc;
20use language::Point;
21use lsp::LanguageServerId;
22use parking_lot::Mutex;
23use pretty_assertions::{assert_eq, assert_matches};
24use project::{FakeFs, Project};
25use serde_json::json;
26use settings::SettingsStore;
27use std::{path::Path, sync::Arc, time::Duration};
28use util::{path, rel_path::rel_path};
29use uuid::Uuid;
30use zeta_prompt::ZetaPromptInput;
31
32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
33
34#[gpui::test]
35async fn test_current_state(cx: &mut TestAppContext) {
36 let (ep_store, mut requests) = init_test_with_fake_client(cx);
37 let fs = FakeFs::new(cx.executor());
38 fs.insert_tree(
39 "/root",
40 json!({
41 "1.txt": "Hello!\nHow\nBye\n",
42 "2.txt": "Hola!\nComo\nAdios\n"
43 }),
44 )
45 .await;
46 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
47
48 let buffer1 = project
49 .update(cx, |project, cx| {
50 let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
51 project.set_active_path(Some(path.clone()), cx);
52 project.open_buffer(path, cx)
53 })
54 .await
55 .unwrap();
56 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
57 let position = snapshot1.anchor_before(language::Point::new(1, 3));
58
59 ep_store.update(cx, |ep_store, cx| {
60 ep_store.register_project(&project, cx);
61 ep_store.register_buffer(&buffer1, &project, cx);
62 });
63
64 // Prediction for current file
65
66 ep_store.update(cx, |ep_store, cx| {
67 ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
68 });
69 let (request, respond_tx) = requests.predict.next().await.unwrap();
70
71 respond_tx
72 .send(model_response(
73 &request,
74 indoc! {r"
75 --- a/root/1.txt
76 +++ b/root/1.txt
77 @@ ... @@
78 Hello!
79 -How
80 +How are you?
81 Bye
82 "},
83 ))
84 .unwrap();
85
86 cx.run_until_parked();
87
88 ep_store.update(cx, |ep_store, cx| {
89 let prediction = ep_store
90 .prediction_at(&buffer1, None, &project, cx)
91 .unwrap();
92 assert_matches!(prediction, BufferEditPrediction::Local { .. });
93 });
94
95 ep_store.update(cx, |ep_store, _cx| {
96 ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
97 });
98
99 // Prediction for diagnostic in another file
100
101 let diagnostic = lsp::Diagnostic {
102 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
103 severity: Some(lsp::DiagnosticSeverity::ERROR),
104 message: "Sentence is incomplete".to_string(),
105 ..Default::default()
106 };
107
108 project.update(cx, |project, cx| {
109 project.lsp_store().update(cx, |lsp_store, cx| {
110 lsp_store
111 .update_diagnostics(
112 LanguageServerId(0),
113 lsp::PublishDiagnosticsParams {
114 uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
115 diagnostics: vec![diagnostic],
116 version: None,
117 },
118 None,
119 language::DiagnosticSourceKind::Pushed,
120 &[],
121 cx,
122 )
123 .unwrap();
124 });
125 });
126
127 let (request, respond_tx) = requests.predict.next().await.unwrap();
128 respond_tx
129 .send(model_response(
130 &request,
131 indoc! {r#"
132 --- a/root/2.txt
133 +++ b/root/2.txt
134 @@ ... @@
135 Hola!
136 -Como
137 +Como estas?
138 Adios
139 "#},
140 ))
141 .unwrap();
142 cx.run_until_parked();
143
144 ep_store.update(cx, |ep_store, cx| {
145 let prediction = ep_store
146 .prediction_at(&buffer1, None, &project, cx)
147 .unwrap();
148 assert_matches!(
149 prediction,
150 BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
151 );
152 });
153
154 let buffer2 = project
155 .update(cx, |project, cx| {
156 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
157 project.open_buffer(path, cx)
158 })
159 .await
160 .unwrap();
161
162 ep_store.update(cx, |ep_store, cx| {
163 let prediction = ep_store
164 .prediction_at(&buffer2, None, &project, cx)
165 .unwrap();
166 assert_matches!(prediction, BufferEditPrediction::Local { .. });
167 });
168}
169
170#[gpui::test]
171async fn test_simple_request(cx: &mut TestAppContext) {
172 let (ep_store, mut requests) = init_test_with_fake_client(cx);
173 let fs = FakeFs::new(cx.executor());
174 fs.insert_tree(
175 "/root",
176 json!({
177 "foo.md": "Hello!\nHow\nBye\n"
178 }),
179 )
180 .await;
181 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
182
183 let buffer = project
184 .update(cx, |project, cx| {
185 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
186 project.open_buffer(path, cx)
187 })
188 .await
189 .unwrap();
190 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
191 let position = snapshot.anchor_before(language::Point::new(1, 3));
192
193 let prediction_task = ep_store.update(cx, |ep_store, cx| {
194 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
195 });
196
197 let (request, respond_tx) = requests.predict.next().await.unwrap();
198
199 // TODO Put back when we have a structured request again
200 // assert_eq!(
201 // request.excerpt_path.as_ref(),
202 // Path::new(path!("root/foo.md"))
203 // );
204 // assert_eq!(
205 // request.cursor_point,
206 // Point {
207 // line: Line(1),
208 // column: 3
209 // }
210 // );
211
212 respond_tx
213 .send(model_response(
214 &request,
215 indoc! { r"
216 --- a/root/foo.md
217 +++ b/root/foo.md
218 @@ ... @@
219 Hello!
220 -How
221 +How are you?
222 Bye
223 "},
224 ))
225 .unwrap();
226
227 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
228
229 assert_eq!(prediction.edits.len(), 1);
230 assert_eq!(
231 prediction.edits[0].0.to_point(&snapshot).start,
232 language::Point::new(1, 3)
233 );
234 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
235}
236
237#[gpui::test]
238async fn test_request_events(cx: &mut TestAppContext) {
239 let (ep_store, mut requests) = init_test_with_fake_client(cx);
240 let fs = FakeFs::new(cx.executor());
241 fs.insert_tree(
242 "/root",
243 json!({
244 "foo.md": "Hello!\n\nBye\n"
245 }),
246 )
247 .await;
248 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
249
250 let buffer = project
251 .update(cx, |project, cx| {
252 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
253 project.open_buffer(path, cx)
254 })
255 .await
256 .unwrap();
257
258 ep_store.update(cx, |ep_store, cx| {
259 ep_store.register_buffer(&buffer, &project, cx);
260 });
261
262 buffer.update(cx, |buffer, cx| {
263 buffer.edit(vec![(7..7, "How")], None, cx);
264 });
265
266 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
267 let position = snapshot.anchor_before(language::Point::new(1, 3));
268
269 let prediction_task = ep_store.update(cx, |ep_store, cx| {
270 ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
271 });
272
273 let (request, respond_tx) = requests.predict.next().await.unwrap();
274
275 let prompt = prompt_from_request(&request);
276 assert!(
277 prompt.contains(indoc! {"
278 --- a/root/foo.md
279 +++ b/root/foo.md
280 @@ -1,3 +1,3 @@
281 Hello!
282 -
283 +How
284 Bye
285 "}),
286 "{prompt}"
287 );
288
289 respond_tx
290 .send(model_response(
291 &request,
292 indoc! {r#"
293 --- a/root/foo.md
294 +++ b/root/foo.md
295 @@ ... @@
296 Hello!
297 -How
298 +How are you?
299 Bye
300 "#},
301 ))
302 .unwrap();
303
304 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
305
306 assert_eq!(prediction.edits.len(), 1);
307 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
308}
309
/// Compares the two edit-history getters for a single coalesced event that
/// spans a pause.
///
/// `edit_history_for_project` returns the edits as one event.
/// `edit_history_for_project_with_pause_split_last_event` splits that last event
/// at the pause into two diffs: the edit before the pause, and the edits after it.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How" at offset 7, the start of the empty second line
    // (just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (`LAST_CHANGE_GROUPING_TIME`).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" at offset 10, immediately after "How" on
    // the same line.
    //
    // Both bursts stay on the same line so that line-span coalescing groups them
    // into a single `LastEvent`. That lets the pause-split getter return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit. It advances the last
    // edit timestamp past the recorded pause boundary, which makes pause-splitting
    // deterministic. Offset 19 is the end of "How are you?" (10 + 9).
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events: the pre-pause
    // insertion, then the post-pause edits applied on top of it.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
404
405#[gpui::test]
406async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
407 let (ep_store, _requests) = init_test_with_fake_client(cx);
408 let fs = FakeFs::new(cx.executor());
409
410 // Create a file with 30 lines to test line-based coalescing
411 let content = (1..=30)
412 .map(|i| format!("Line {}\n", i))
413 .collect::<String>();
414 fs.insert_tree(
415 "/root",
416 json!({
417 "foo.md": content
418 }),
419 )
420 .await;
421 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
422
423 let buffer = project
424 .update(cx, |project, cx| {
425 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
426 project.open_buffer(path, cx)
427 })
428 .await
429 .unwrap();
430
431 ep_store.update(cx, |ep_store, cx| {
432 ep_store.register_buffer(&buffer, &project, cx);
433 });
434
435 // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
436 buffer.update(cx, |buffer, cx| {
437 let start = Point::new(10, 0).to_offset(buffer);
438 let end = Point::new(13, 0).to_offset(buffer);
439 buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
440 });
441
442 let events = ep_store.update(cx, |ep_store, cx| {
443 ep_store.edit_history_for_project(&project, cx)
444 });
445 assert_eq!(
446 render_events(&events),
447 indoc! {"
448 @@ -8,9 +8,8 @@
449 Line 8
450 Line 9
451 Line 10
452 -Line 11
453 -Line 12
454 -Line 13
455 +Middle A
456 +Middle B
457 Line 14
458 Line 15
459 Line 16
460 "},
461 "After first edit"
462 );
463
464 // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
465 // This tests that coalescing considers the START of the existing range
466 buffer.update(cx, |buffer, cx| {
467 let offset = Point::new(5, 0).to_offset(buffer);
468 buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
469 });
470
471 let events = ep_store.update(cx, |ep_store, cx| {
472 ep_store.edit_history_for_project(&project, cx)
473 });
474 assert_eq!(
475 render_events(&events),
476 indoc! {"
477 @@ -3,14 +3,14 @@
478 Line 3
479 Line 4
480 Line 5
481 +Above
482 Line 6
483 Line 7
484 Line 8
485 Line 9
486 Line 10
487 -Line 11
488 -Line 12
489 -Line 13
490 +Middle A
491 +Middle B
492 Line 14
493 Line 15
494 Line 16
495 "},
496 "After inserting above (should coalesce)"
497 );
498
499 // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
500 // This tests that coalescing considers the END of the existing range
501 buffer.update(cx, |buffer, cx| {
502 let offset = Point::new(14, 0).to_offset(buffer);
503 buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
504 });
505
506 let events = ep_store.update(cx, |ep_store, cx| {
507 ep_store.edit_history_for_project(&project, cx)
508 });
509 assert_eq!(
510 render_events(&events),
511 indoc! {"
512 @@ -3,15 +3,16 @@
513 Line 3
514 Line 4
515 Line 5
516 +Above
517 Line 6
518 Line 7
519 Line 8
520 Line 9
521 Line 10
522 -Line 11
523 -Line 12
524 -Line 13
525 +Middle A
526 +Middle B
527 Line 14
528 +Below
529 Line 15
530 Line 16
531 Line 17
532 "},
533 "After inserting below (should coalesce)"
534 );
535
536 // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
537 // This should NOT coalesce - creates a new event
538 buffer.update(cx, |buffer, cx| {
539 let offset = Point::new(25, 0).to_offset(buffer);
540 buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
541 });
542
543 let events = ep_store.update(cx, |ep_store, cx| {
544 ep_store.edit_history_for_project(&project, cx)
545 });
546 assert_eq!(
547 render_events(&events),
548 indoc! {"
549 @@ -3,15 +3,16 @@
550 Line 3
551 Line 4
552 Line 5
553 +Above
554 Line 6
555 Line 7
556 Line 8
557 Line 9
558 Line 10
559 -Line 11
560 -Line 12
561 -Line 13
562 +Middle A
563 +Middle B
564 Line 14
565 +Below
566 Line 15
567 Line 16
568 Line 17
569
570 ---
571 @@ -23,6 +23,7 @@
572 Line 22
573 Line 23
574 Line 24
575 +Far below
576 Line 25
577 Line 26
578 Line 27
579 "},
580 "After inserting far below (should NOT coalesce)"
581 );
582}
583
584fn render_events(events: &[StoredEvent]) -> String {
585 events
586 .iter()
587 .map(|e| {
588 let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
589 diff.as_str()
590 })
591 .collect::<Vec<_>>()
592 .join("\n---\n")
593}
594
595#[gpui::test]
596async fn test_empty_prediction(cx: &mut TestAppContext) {
597 let (ep_store, mut requests) = init_test_with_fake_client(cx);
598 let fs = FakeFs::new(cx.executor());
599 fs.insert_tree(
600 "/root",
601 json!({
602 "foo.md": "Hello!\nHow\nBye\n"
603 }),
604 )
605 .await;
606 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
607
608 let buffer = project
609 .update(cx, |project, cx| {
610 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
611 project.open_buffer(path, cx)
612 })
613 .await
614 .unwrap();
615 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
616 let position = snapshot.anchor_before(language::Point::new(1, 3));
617
618 ep_store.update(cx, |ep_store, cx| {
619 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
620 });
621
622 let (request, respond_tx) = requests.predict.next().await.unwrap();
623 let response = model_response(&request, "");
624 let id = response.request_id.clone();
625 respond_tx.send(response).unwrap();
626
627 cx.run_until_parked();
628
629 ep_store.update(cx, |ep_store, cx| {
630 assert!(
631 ep_store
632 .prediction_at(&buffer, None, &project, cx)
633 .is_none()
634 );
635 });
636
637 // prediction is reported as rejected
638 let (reject_request, _) = requests.reject.next().await.unwrap();
639
640 assert_eq!(
641 &reject_request.rejections,
642 &[EditPredictionRejection {
643 request_id: id,
644 reason: EditPredictionRejectReason::Empty,
645 was_shown: false
646 }]
647 );
648}
649
650#[gpui::test]
651async fn test_interpolated_empty(cx: &mut TestAppContext) {
652 let (ep_store, mut requests) = init_test_with_fake_client(cx);
653 let fs = FakeFs::new(cx.executor());
654 fs.insert_tree(
655 "/root",
656 json!({
657 "foo.md": "Hello!\nHow\nBye\n"
658 }),
659 )
660 .await;
661 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
662
663 let buffer = project
664 .update(cx, |project, cx| {
665 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
666 project.open_buffer(path, cx)
667 })
668 .await
669 .unwrap();
670 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
671 let position = snapshot.anchor_before(language::Point::new(1, 3));
672
673 ep_store.update(cx, |ep_store, cx| {
674 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
675 });
676
677 let (request, respond_tx) = requests.predict.next().await.unwrap();
678
679 buffer.update(cx, |buffer, cx| {
680 buffer.set_text("Hello!\nHow are you?\nBye", cx);
681 });
682
683 let response = model_response(&request, SIMPLE_DIFF);
684 let id = response.request_id.clone();
685 respond_tx.send(response).unwrap();
686
687 cx.run_until_parked();
688
689 ep_store.update(cx, |ep_store, cx| {
690 assert!(
691 ep_store
692 .prediction_at(&buffer, None, &project, cx)
693 .is_none()
694 );
695 });
696
697 // prediction is reported as rejected
698 let (reject_request, _) = requests.reject.next().await.unwrap();
699
700 assert_eq!(
701 &reject_request.rejections,
702 &[EditPredictionRejection {
703 request_id: id,
704 reason: EditPredictionRejectReason::InterpolatedEmpty,
705 was_shown: false
706 }]
707 );
708}
709
/// Model response shared by several tests: rewrites the second line of
/// `root/foo.md` from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
719
720#[gpui::test]
721async fn test_replace_current(cx: &mut TestAppContext) {
722 let (ep_store, mut requests) = init_test_with_fake_client(cx);
723 let fs = FakeFs::new(cx.executor());
724 fs.insert_tree(
725 "/root",
726 json!({
727 "foo.md": "Hello!\nHow\nBye\n"
728 }),
729 )
730 .await;
731 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
732
733 let buffer = project
734 .update(cx, |project, cx| {
735 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
736 project.open_buffer(path, cx)
737 })
738 .await
739 .unwrap();
740 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
741 let position = snapshot.anchor_before(language::Point::new(1, 3));
742
743 ep_store.update(cx, |ep_store, cx| {
744 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
745 });
746
747 let (request, respond_tx) = requests.predict.next().await.unwrap();
748 let first_response = model_response(&request, SIMPLE_DIFF);
749 let first_id = first_response.request_id.clone();
750 respond_tx.send(first_response).unwrap();
751
752 cx.run_until_parked();
753
754 ep_store.update(cx, |ep_store, cx| {
755 assert_eq!(
756 ep_store
757 .prediction_at(&buffer, None, &project, cx)
758 .unwrap()
759 .id
760 .0,
761 first_id
762 );
763 });
764
765 // a second request is triggered
766 ep_store.update(cx, |ep_store, cx| {
767 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
768 });
769
770 let (request, respond_tx) = requests.predict.next().await.unwrap();
771 let second_response = model_response(&request, SIMPLE_DIFF);
772 let second_id = second_response.request_id.clone();
773 respond_tx.send(second_response).unwrap();
774
775 cx.run_until_parked();
776
777 ep_store.update(cx, |ep_store, cx| {
778 // second replaces first
779 assert_eq!(
780 ep_store
781 .prediction_at(&buffer, None, &project, cx)
782 .unwrap()
783 .id
784 .0,
785 second_id
786 );
787 });
788
789 // first is reported as replaced
790 let (reject_request, _) = requests.reject.next().await.unwrap();
791
792 assert_eq!(
793 &reject_request.rejections,
794 &[EditPredictionRejection {
795 request_id: first_id,
796 reason: EditPredictionRejectReason::Replaced,
797 was_shown: false
798 }]
799 );
800}
801
802#[gpui::test]
803async fn test_current_preferred(cx: &mut TestAppContext) {
804 let (ep_store, mut requests) = init_test_with_fake_client(cx);
805 let fs = FakeFs::new(cx.executor());
806 fs.insert_tree(
807 "/root",
808 json!({
809 "foo.md": "Hello!\nHow\nBye\n"
810 }),
811 )
812 .await;
813 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
814
815 let buffer = project
816 .update(cx, |project, cx| {
817 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
818 project.open_buffer(path, cx)
819 })
820 .await
821 .unwrap();
822 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
823 let position = snapshot.anchor_before(language::Point::new(1, 3));
824
825 ep_store.update(cx, |ep_store, cx| {
826 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
827 });
828
829 let (request, respond_tx) = requests.predict.next().await.unwrap();
830 let first_response = model_response(&request, SIMPLE_DIFF);
831 let first_id = first_response.request_id.clone();
832 respond_tx.send(first_response).unwrap();
833
834 cx.run_until_parked();
835
836 ep_store.update(cx, |ep_store, cx| {
837 assert_eq!(
838 ep_store
839 .prediction_at(&buffer, None, &project, cx)
840 .unwrap()
841 .id
842 .0,
843 first_id
844 );
845 });
846
847 // a second request is triggered
848 ep_store.update(cx, |ep_store, cx| {
849 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
850 });
851
852 let (request, respond_tx) = requests.predict.next().await.unwrap();
853 // worse than current prediction
854 let second_response = model_response(
855 &request,
856 indoc! { r"
857 --- a/root/foo.md
858 +++ b/root/foo.md
859 @@ ... @@
860 Hello!
861 -How
862 +How are
863 Bye
864 "},
865 );
866 let second_id = second_response.request_id.clone();
867 respond_tx.send(second_response).unwrap();
868
869 cx.run_until_parked();
870
871 ep_store.update(cx, |ep_store, cx| {
872 // first is preferred over second
873 assert_eq!(
874 ep_store
875 .prediction_at(&buffer, None, &project, cx)
876 .unwrap()
877 .id
878 .0,
879 first_id
880 );
881 });
882
883 // second is reported as rejected
884 let (reject_request, _) = requests.reject.next().await.unwrap();
885
886 assert_eq!(
887 &reject_request.rejections,
888 &[EditPredictionRejection {
889 request_id: second_id,
890 reason: EditPredictionRejectReason::CurrentPreferred,
891 was_shown: false
892 }]
893 );
894}
895
896#[gpui::test]
897async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
898 let (ep_store, mut requests) = init_test_with_fake_client(cx);
899 let fs = FakeFs::new(cx.executor());
900 fs.insert_tree(
901 "/root",
902 json!({
903 "foo.md": "Hello!\nHow\nBye\n"
904 }),
905 )
906 .await;
907 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
908
909 let buffer = project
910 .update(cx, |project, cx| {
911 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
912 project.open_buffer(path, cx)
913 })
914 .await
915 .unwrap();
916 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
917 let position = snapshot.anchor_before(language::Point::new(1, 3));
918
919 // start two refresh tasks
920 ep_store.update(cx, |ep_store, cx| {
921 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
922 });
923
924 let (request1, respond_first) = requests.predict.next().await.unwrap();
925
926 ep_store.update(cx, |ep_store, cx| {
927 ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
928 });
929
930 let (request, respond_second) = requests.predict.next().await.unwrap();
931
932 // wait for throttle
933 cx.run_until_parked();
934
935 // second responds first
936 let second_response = model_response(&request, SIMPLE_DIFF);
937 let second_id = second_response.request_id.clone();
938 respond_second.send(second_response).unwrap();
939
940 cx.run_until_parked();
941
942 ep_store.update(cx, |ep_store, cx| {
943 // current prediction is second
944 assert_eq!(
945 ep_store
946 .prediction_at(&buffer, None, &project, cx)
947 .unwrap()
948 .id
949 .0,
950 second_id
951 );
952 });
953
954 let first_response = model_response(&request1, SIMPLE_DIFF);
955 let first_id = first_response.request_id.clone();
956 respond_first.send(first_response).unwrap();
957
958 cx.run_until_parked();
959
960 ep_store.update(cx, |ep_store, cx| {
961 // current prediction is still second, since first was cancelled
962 assert_eq!(
963 ep_store
964 .prediction_at(&buffer, None, &project, cx)
965 .unwrap()
966 .id
967 .0,
968 second_id
969 );
970 });
971
972 // first is reported as rejected
973 let (reject_request, _) = requests.reject.next().await.unwrap();
974
975 cx.run_until_parked();
976
977 assert_eq!(
978 &reject_request.rejections,
979 &[EditPredictionRejection {
980 request_id: first_id,
981 reason: EditPredictionRejectReason::Canceled,
982 was_shown: false
983 }]
984 );
985}
986
/// With two requests already pending, starting a third request cancels the
/// second.
///
/// The cancelled response is ignored when it arrives. The third prediction then
/// replaces the first. A single rejection report contains the cancelled second
/// (`Canceled`) followed by the replaced first (`Replaced`).
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Start two refresh tasks.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // Wait for throttling, so both requests are sent.
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // Start a third request.
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // Two requests are pending, so the second is cancelled.
        // NOTE(review): `[1]` presumably identifies the second request by its
        // zero-based index. TODO confirm against `cancelled_predictions`.
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // Wait for throttling.
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // The current prediction is the first one.
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // The current prediction is still the first one, since the second was cancelled.
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // The third request completes and replaces the first.
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both the cancelled second request and the replaced first prediction are
    // reported as rejected, in one batch.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1122
1123#[gpui::test]
1124async fn test_rejections_flushing(cx: &mut TestAppContext) {
1125 let (ep_store, mut requests) = init_test_with_fake_client(cx);
1126
1127 ep_store.update(cx, |ep_store, _cx| {
1128 ep_store.reject_prediction(
1129 EditPredictionId("test-1".into()),
1130 EditPredictionRejectReason::Discarded,
1131 false,
1132 );
1133 ep_store.reject_prediction(
1134 EditPredictionId("test-2".into()),
1135 EditPredictionRejectReason::Canceled,
1136 true,
1137 );
1138 });
1139
1140 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1141 cx.run_until_parked();
1142
1143 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1144 respond_tx.send(()).unwrap();
1145
1146 // batched
1147 assert_eq!(reject_request.rejections.len(), 2);
1148 assert_eq!(
1149 reject_request.rejections[0],
1150 EditPredictionRejection {
1151 request_id: "test-1".to_string(),
1152 reason: EditPredictionRejectReason::Discarded,
1153 was_shown: false
1154 }
1155 );
1156 assert_eq!(
1157 reject_request.rejections[1],
1158 EditPredictionRejection {
1159 request_id: "test-2".to_string(),
1160 reason: EditPredictionRejectReason::Canceled,
1161 was_shown: true
1162 }
1163 );
1164
1165 // Reaching batch size limit sends without debounce
1166 ep_store.update(cx, |ep_store, _cx| {
1167 for i in 0..70 {
1168 ep_store.reject_prediction(
1169 EditPredictionId(format!("batch-{}", i).into()),
1170 EditPredictionRejectReason::Discarded,
1171 false,
1172 );
1173 }
1174 });
1175
1176 // First MAX/2 items are sent immediately
1177 cx.run_until_parked();
1178 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1179 respond_tx.send(()).unwrap();
1180
1181 assert_eq!(reject_request.rejections.len(), 50);
1182 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1183 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1184
1185 // Remaining items are debounced with the next batch
1186 cx.executor().advance_clock(Duration::from_secs(15));
1187 cx.run_until_parked();
1188
1189 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1190 respond_tx.send(()).unwrap();
1191
1192 assert_eq!(reject_request.rejections.len(), 20);
1193 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1194 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1195
1196 // Request failure
1197 ep_store.update(cx, |ep_store, _cx| {
1198 ep_store.reject_prediction(
1199 EditPredictionId("retry-1".into()),
1200 EditPredictionRejectReason::Discarded,
1201 false,
1202 );
1203 });
1204
1205 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1206 cx.run_until_parked();
1207
1208 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1209 assert_eq!(reject_request.rejections.len(), 1);
1210 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1211 // Simulate failure
1212 drop(_respond_tx);
1213
1214 // Add another rejection
1215 ep_store.update(cx, |ep_store, _cx| {
1216 ep_store.reject_prediction(
1217 EditPredictionId("retry-2".into()),
1218 EditPredictionRejectReason::Discarded,
1219 false,
1220 );
1221 });
1222
1223 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1224 cx.run_until_parked();
1225
1226 // Retry should include both the failed item and the new one
1227 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1228 respond_tx.send(()).unwrap();
1229
1230 assert_eq!(reject_request.rejections.len(), 2);
1231 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1232 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1233}
1234
1235// Skipped until we start including diagnostics in prompt
1236// #[gpui::test]
1237// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1238// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1239// let fs = FakeFs::new(cx.executor());
1240// fs.insert_tree(
1241// "/root",
1242// json!({
1243// "foo.md": "Hello!\nBye"
1244// }),
1245// )
1246// .await;
1247// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1248
1249// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1250// let diagnostic = lsp::Diagnostic {
1251// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1252// severity: Some(lsp::DiagnosticSeverity::ERROR),
1253// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1254// ..Default::default()
1255// };
1256
1257// project.update(cx, |project, cx| {
1258// project.lsp_store().update(cx, |lsp_store, cx| {
1259// // Create some diagnostics
1260// lsp_store
1261// .update_diagnostics(
1262// LanguageServerId(0),
1263// lsp::PublishDiagnosticsParams {
1264// uri: path_to_buffer_uri.clone(),
1265// diagnostics: vec![diagnostic],
1266// version: None,
1267// },
1268// None,
1269// language::DiagnosticSourceKind::Pushed,
1270// &[],
1271// cx,
1272// )
1273// .unwrap();
1274// });
1275// });
1276
1277// let buffer = project
1278// .update(cx, |project, cx| {
1279// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1280// project.open_buffer(path, cx)
1281// })
1282// .await
1283// .unwrap();
1284
1285// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1286// let position = snapshot.anchor_before(language::Point::new(0, 0));
1287
1288// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1289// ep_store.request_prediction(&project, &buffer, position, cx)
1290// });
1291
1292// let (request, _respond_tx) = req_rx.next().await.unwrap();
1293
1294// assert_eq!(request.diagnostic_groups.len(), 1);
1295// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1296// .unwrap();
1297// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1298// assert_eq!(
1299// value,
1300// json!({
1301// "entries": [{
1302// "range": {
1303// "start": 8,
1304// "end": 10
1305// },
1306// "diagnostic": {
1307// "source": null,
1308// "code": null,
1309// "code_description": null,
1310// "severity": 1,
1311// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1312// "markdown": null,
1313// "group_id": 0,
1314// "is_primary": true,
1315// "is_disk_based": false,
1316// "is_unnecessary": false,
1317// "source_kind": "Pushed",
1318// "data": null,
1319// "underline": true
1320// }
1321// }],
1322// "primary_ix": 0
1323// })
1324// );
1325// }
1326
1327// Generate a model response that would apply the given diff to the active file.
1328fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1329 let excerpt =
1330 request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1331 let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1332
1333 PredictEditsV3Response {
1334 request_id: Uuid::new_v4().to_string(),
1335 output: new_excerpt,
1336 }
1337}
1338
1339fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1340 zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
1341}
1342
/// Receiving ends for the requests captured by the fake HTTP client that
/// `init_test_with_fake_client` installs.
///
/// Each item pairs the deserialized request body with a oneshot sender. The test
/// answers the request by sending on it, or drops it to make the request fail.
struct RequestChannels {
    /// Bodies posted to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Bodies posted to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1350
1351fn init_test_with_fake_client(
1352 cx: &mut TestAppContext,
1353) -> (Entity<EditPredictionStore>, RequestChannels) {
1354 cx.update(move |cx| {
1355 let settings_store = SettingsStore::test(cx);
1356 cx.set_global(settings_store);
1357 zlog::init_test();
1358
1359 let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1360 let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1361
1362 let http_client = FakeHttpClient::create({
1363 move |req| {
1364 let uri = req.uri().path().to_string();
1365 let mut body = req.into_body();
1366 let predict_req_tx = predict_req_tx.clone();
1367 let reject_req_tx = reject_req_tx.clone();
1368 async move {
1369 let resp = match uri.as_str() {
1370 "/client/llm_tokens" => serde_json::to_string(&json!({
1371 "token": "test"
1372 }))
1373 .unwrap(),
1374 "/predict_edits/v3" => {
1375 let mut buf = Vec::new();
1376 body.read_to_end(&mut buf).await.ok();
1377 let req = serde_json::from_slice(&buf).unwrap();
1378
1379 let (res_tx, res_rx) = oneshot::channel();
1380 predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1381 serde_json::to_string(&res_rx.await?).unwrap()
1382 }
1383 "/predict_edits/reject" => {
1384 let mut buf = Vec::new();
1385 body.read_to_end(&mut buf).await.ok();
1386 let req = serde_json::from_slice(&buf).unwrap();
1387
1388 let (res_tx, res_rx) = oneshot::channel();
1389 reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1390 serde_json::to_string(&res_rx.await?).unwrap()
1391 }
1392 _ => {
1393 panic!("Unexpected path: {}", uri)
1394 }
1395 };
1396
1397 Ok(Response::builder().body(resp.into()).unwrap())
1398 }
1399 }
1400 });
1401
1402 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1403 client.cloud_client().set_credentials(1, "test".into());
1404
1405 language_model::init(client.clone(), cx);
1406
1407 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1408 let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1409
1410 (
1411 ep_store,
1412 RequestChannels {
1413 predict: predict_req_rx,
1414 reject: reject_req_rx,
1415 },
1416 )
1417 })
1418}
1419
/// Text of the 0BSD license. The data-collection tests write it as a worktree's
/// `LICENSE` file to make that worktree count as open source.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1421
// Checks that `EditPrediction::interpolate` re-expresses a prediction's edits
// against later buffer states:
// - user typing that matches the prediction shrinks the remaining edits;
// - undo restores them;
// - an edit that diverges from the prediction makes `interpolate` return `None`.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: replace "rem" (2..5) with "REM" and delete "um" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // The prompt `inputs` are empty placeholders; this test only exercises
    // interpolation of `edits`.
    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Buffer unchanged: the edits come back exactly as predicted.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting "rem" turns the replacement into a pure insertion and shifts
        // the later deletion left by three.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original prediction.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the first predicted character leaves only "EM" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the last predicted character completes the replacement; only
        // the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Removing the just-typed "M" brings its insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion by hand leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // Deleting text the prediction did not touch invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1535
// Applying a Zeta1 response whose editable region ends with an extra blank
// line (before `<|editable_region_end|>`) must produce only the intended code
// change and no extra trailing newline in the buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Mid-line replacement: `word` -> `word_1` in two places.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word.len()..word.len();
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let word_1 = \"lorem\";
                let range = word_1.len()..word_1.len();
            }
        "},
    );

    // Completion that extends the existing text of a line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                fn main() {
                    let story = \"the quick\"
                }
            "},
            indoc! {"
                <|editable_region_start|>
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }

                <|editable_region_end|>
            "},
            cx,
        )
        .await,
        indoc! {"
            fn main() {
                let story = \"the quick brown fox jumps over the lazy dog\";
            }
        "},
    );
}
1593
// A prediction that appends a line at the very end of the buffer is applied
// as-is. The response's trailing newline before `<|editable_region_end|>` does
// not add a newline after the appended "ipsum".
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
        ```animals.js
        <|start_of_file|>
        <|editable_region_start|>
        lorem
        ipsum
        <|editable_region_end|>
        ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1613
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "hello".
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Intercept the outgoing `/predict_edits/v3` request and answer it manually.
    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
1678
// In a worktree with an open-source LICENSE, the request's `can_collect_data`
// flag follows the user's data-collection choice.
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    // NOTE(review): `src/main.rs` is not part of the fixture tree above, so this
    // opens a buffer for a path that does not exist on the fake fs. Confirm this
    // is intended.
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Opting out must be reflected in the very next request.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1716
// A buffer created with `Buffer::remote` is not eligible for data collection,
// even with collection enabled.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1744
1745#[gpui::test]
1746async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1747 init_test(cx);
1748
1749 let fs = project::FakeFs::new(cx.executor());
1750 fs.insert_tree(
1751 path!("/project"),
1752 json!({
1753 "LICENSE": BSD_0_TXT,
1754 ".env": "SECRET_KEY=secret"
1755 }),
1756 )
1757 .await;
1758
1759 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1760 let buffer = project
1761 .update(cx, |project, cx| {
1762 project.open_local_buffer("/project/.env", cx)
1763 })
1764 .await
1765 .unwrap();
1766
1767 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1768 ep_store.update(cx, |ep_store, _cx| {
1769 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1770 });
1771
1772 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1773 assert_eq!(
1774 captured_request.lock().clone().unwrap().can_collect_data,
1775 false
1776 );
1777}
1778
// A local buffer with no backing file (untitled) is not eligible for data
// collection, even with collection enabled.
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1798
1799#[gpui::test]
1800async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1801 init_test(cx);
1802
1803 let fs = project::FakeFs::new(cx.executor());
1804 fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1805 .await;
1806
1807 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1808 let buffer = project
1809 .update(cx, |project, cx| {
1810 project.open_local_buffer("/project/main.rs", cx)
1811 })
1812 .await
1813 .unwrap();
1814
1815 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1816 ep_store.update(cx, |ep_store, _cx| {
1817 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1818 });
1819
1820 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1821 assert_eq!(
1822 captured_request.lock().clone().unwrap().can_collect_data,
1823 false
1824 );
1825}
1826
// Re-pointing a buffer's file from a worktree with an open-source LICENSE to
// one without a LICENSE flips `can_collect_data` from true to false on the next
// prediction request.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Load `main.rs` from the worktree without a LICENSE and install it as the
    // buffer's file, simulating the buffer being moved there.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1890
1891#[gpui::test]
1892async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1893 init_test(cx);
1894
1895 let fs = project::FakeFs::new(cx.executor());
1896 fs.insert_tree(
1897 path!("/worktree1"),
1898 json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1899 )
1900 .await;
1901 fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1902 .await;
1903
1904 let project = Project::test(
1905 fs.clone(),
1906 [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1907 cx,
1908 )
1909 .await;
1910 let buffer = project
1911 .update(cx, |project, cx| {
1912 project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1913 })
1914 .await
1915 .unwrap();
1916 let private_buffer = project
1917 .update(cx, |project, cx| {
1918 project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1919 })
1920 .await
1921 .unwrap();
1922
1923 let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1924 ep_store.update(cx, |ep_store, _cx| {
1925 ep_store.data_collection_choice = DataCollectionChoice::Enabled
1926 });
1927
1928 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1929 assert_eq!(
1930 captured_request.lock().clone().unwrap().can_collect_data,
1931 true
1932 );
1933
1934 // this has a side effect of registering the buffer to watch for edits
1935 run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1936 assert_eq!(
1937 captured_request.lock().clone().unwrap().can_collect_data,
1938 false
1939 );
1940
1941 private_buffer.update(cx, |private_buffer, cx| {
1942 private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1943 });
1944
1945 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1946 assert_eq!(
1947 captured_request.lock().clone().unwrap().can_collect_data,
1948 false
1949 );
1950
1951 // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1952 // included
1953 buffer.update(cx, |buffer, cx| {
1954 buffer.edit(
1955 [(
1956 0..0,
1957 " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1958 )],
1959 None,
1960 cx,
1961 );
1962 });
1963
1964 run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1965 assert_eq!(
1966 captured_request.lock().clone().unwrap().can_collect_data,
1967 true
1968 );
1969}
1970
1971fn init_test(cx: &mut TestAppContext) {
1972 cx.update(|cx| {
1973 let settings_store = SettingsStore::test(cx);
1974 cx.set_global(settings_store);
1975 });
1976}
1977
1978async fn apply_edit_prediction(
1979 buffer_content: &str,
1980 completion_response: &str,
1981 cx: &mut TestAppContext,
1982) -> String {
1983 let fs = project::FakeFs::new(cx.executor());
1984 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1985 let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1986 let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1987 *response.lock() = completion_response.to_string();
1988 let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1989 buffer.update(cx, |buffer, cx| {
1990 buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1991 });
1992 buffer.read_with(cx, |buffer, _| buffer.text())
1993}
1994
1995async fn run_edit_prediction(
1996 buffer: &Entity<Buffer>,
1997 project: &Entity<Project>,
1998 ep_store: &Entity<EditPredictionStore>,
1999 cx: &mut TestAppContext,
2000) -> EditPrediction {
2001 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2002 ep_store.update(cx, |ep_store, cx| {
2003 ep_store.register_buffer(buffer, &project, cx)
2004 });
2005 cx.background_executor.run_until_parked();
2006 let prediction_task = ep_store.update(cx, |ep_store, cx| {
2007 ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2008 });
2009 prediction_task.await.unwrap().unwrap().prediction.unwrap()
2010}
2011
2012async fn make_test_ep_store(
2013 project: &Entity<Project>,
2014 cx: &mut TestAppContext,
2015) -> (
2016 Entity<EditPredictionStore>,
2017 Arc<Mutex<Option<PredictEditsBody>>>,
2018 Arc<Mutex<String>>,
2019) {
2020 let default_response = indoc! {"
2021 ```main.rs
2022 <|start_of_file|>
2023 <|editable_region_start|>
2024 hello world
2025 <|editable_region_end|>
2026 ```"
2027 };
2028 let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2029 let completion_response: Arc<Mutex<String>> =
2030 Arc::new(Mutex::new(default_response.to_string()));
2031 let http_client = FakeHttpClient::create({
2032 let captured_request = captured_request.clone();
2033 let completion_response = completion_response.clone();
2034 let mut next_request_id = 0;
2035 move |req| {
2036 let captured_request = captured_request.clone();
2037 let completion_response = completion_response.clone();
2038 async move {
2039 match (req.method(), req.uri().path()) {
2040 (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2041 .status(200)
2042 .body(
2043 serde_json::to_string(&CreateLlmTokenResponse {
2044 token: LlmToken("the-llm-token".to_string()),
2045 })
2046 .unwrap()
2047 .into(),
2048 )
2049 .unwrap()),
2050 (&Method::POST, "/predict_edits/v2") => {
2051 let mut request_body = String::new();
2052 req.into_body().read_to_string(&mut request_body).await?;
2053 *captured_request.lock() =
2054 Some(serde_json::from_str(&request_body).unwrap());
2055 next_request_id += 1;
2056 Ok(http_client::Response::builder()
2057 .status(200)
2058 .body(
2059 serde_json::to_string(&PredictEditsResponse {
2060 request_id: format!("request-{next_request_id}"),
2061 output_excerpt: completion_response.lock().clone(),
2062 })
2063 .unwrap()
2064 .into(),
2065 )
2066 .unwrap())
2067 }
2068 _ => Ok(http_client::Response::builder()
2069 .status(404)
2070 .body("Not Found".into())
2071 .unwrap()),
2072 }
2073 }
2074 }
2075 });
2076
2077 let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2078 cx.update(|cx| {
2079 RefreshLlmTokenListener::register(client.clone(), cx);
2080 });
2081 let _server = FakeServer::for_client(42, &client, cx).await;
2082
2083 let ep_store = cx.new(|cx| {
2084 let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2085 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2086
2087 let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2088 for worktree in worktrees {
2089 let worktree_id = worktree.read(cx).id();
2090 ep_store
2091 .get_or_init_project(project, cx)
2092 .license_detection_watchers
2093 .entry(worktree_id)
2094 .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2095 }
2096
2097 ep_store
2098 });
2099
2100 (ep_store, captured_request, completion_response)
2101}
2102
2103fn to_completion_edits(
2104 iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2105 buffer: &Entity<Buffer>,
2106 cx: &App,
2107) -> Vec<(Range<Anchor>, Arc<str>)> {
2108 let buffer = buffer.read(cx);
2109 iterator
2110 .into_iter()
2111 .map(|(range, text)| {
2112 (
2113 buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2114 text,
2115 )
2116 })
2117 .collect()
2118}
2119
2120fn from_completion_edits(
2121 editor_edits: &[(Range<Anchor>, Arc<str>)],
2122 buffer: &Entity<Buffer>,
2123 cx: &App,
2124) -> Vec<(Range<usize>, Arc<str>)> {
2125 let buffer = buffer.read(cx);
2126 editor_edits
2127 .iter()
2128 .map(|(range, text)| {
2129 (
2130 range.start.to_offset(buffer)..range.end.to_offset(buffer),
2131 text.clone(),
2132 )
2133 })
2134 .collect()
2135}
2136
2137#[gpui::test]
2138async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2139 init_test(cx);
2140
2141 let fs = FakeFs::new(cx.executor());
2142 fs.insert_tree(
2143 "/project",
2144 serde_json::json!({
2145 "main.rs": "fn main() {\n \n}\n"
2146 }),
2147 )
2148 .await;
2149
2150 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2151
2152 let http_client = FakeHttpClient::create(|_req| async move {
2153 Ok(gpui::http_client::Response::builder()
2154 .status(401)
2155 .body("Unauthorized".into())
2156 .unwrap())
2157 });
2158
2159 let client =
2160 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2161 cx.update(|cx| {
2162 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2163 });
2164
2165 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2166
2167 let buffer = project
2168 .update(cx, |project, cx| {
2169 let path = project
2170 .find_project_path(path!("/project/main.rs"), cx)
2171 .unwrap();
2172 project.open_buffer(path, cx)
2173 })
2174 .await
2175 .unwrap();
2176
2177 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2178 ep_store.update(cx, |ep_store, cx| {
2179 ep_store.register_buffer(&buffer, &project, cx)
2180 });
2181 cx.background_executor.run_until_parked();
2182
2183 let completion_task = ep_store.update(cx, |ep_store, cx| {
2184 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2185 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2186 });
2187
2188 let result = completion_task.await;
2189 assert!(
2190 result.is_err(),
2191 "Without authentication and without custom URL, prediction should fail"
2192 );
2193}
2194
2195#[gpui::test]
2196async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2197 init_test(cx);
2198
2199 let fs = FakeFs::new(cx.executor());
2200 fs.insert_tree(
2201 "/project",
2202 serde_json::json!({
2203 "main.rs": "fn main() {\n \n}\n"
2204 }),
2205 )
2206 .await;
2207
2208 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2209
2210 let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2211 let predict_called_clone = predict_called.clone();
2212
2213 let http_client = FakeHttpClient::create({
2214 move |req| {
2215 let uri = req.uri().path().to_string();
2216 let predict_called = predict_called_clone.clone();
2217 async move {
2218 if uri.contains("predict") {
2219 predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2220 Ok(gpui::http_client::Response::builder()
2221 .body(
2222 serde_json::to_string(&open_ai::Response {
2223 id: "test-123".to_string(),
2224 object: "chat.completion".to_string(),
2225 created: 0,
2226 model: "test".to_string(),
2227 usage: open_ai::Usage {
2228 prompt_tokens: 0,
2229 completion_tokens: 0,
2230 total_tokens: 0,
2231 },
2232 choices: vec![open_ai::Choice {
2233 index: 0,
2234 message: open_ai::RequestMessage::Assistant {
2235 content: Some(open_ai::MessageContent::Plain(
2236 indoc! {"
2237 ```main.rs
2238 <|start_of_file|>
2239 <|editable_region_start|>
2240 fn main() {
2241 println!(\"Hello, world!\");
2242 }
2243 <|editable_region_end|>
2244 ```
2245 "}
2246 .to_string(),
2247 )),
2248 tool_calls: vec![],
2249 },
2250 finish_reason: Some("stop".to_string()),
2251 }],
2252 })
2253 .unwrap()
2254 .into(),
2255 )
2256 .unwrap())
2257 } else {
2258 Ok(gpui::http_client::Response::builder()
2259 .status(401)
2260 .body("Unauthorized".into())
2261 .unwrap())
2262 }
2263 }
2264 }
2265 });
2266
2267 let client =
2268 cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2269 cx.update(|cx| {
2270 language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2271 });
2272
2273 let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2274
2275 let buffer = project
2276 .update(cx, |project, cx| {
2277 let path = project
2278 .find_project_path(path!("/project/main.rs"), cx)
2279 .unwrap();
2280 project.open_buffer(path, cx)
2281 })
2282 .await
2283 .unwrap();
2284
2285 let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2286 ep_store.update(cx, |ep_store, cx| {
2287 ep_store.register_buffer(&buffer, &project, cx)
2288 });
2289 cx.background_executor.run_until_parked();
2290
2291 let completion_task = ep_store.update(cx, |ep_store, cx| {
2292 ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2293 ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2294 ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2295 });
2296
2297 let _ = completion_task.await;
2298
2299 assert!(
2300 predict_called.load(std::sync::atomic::Ordering::SeqCst),
2301 "With custom URL, predict endpoint should be called even without authentication"
2302 );
2303}
2304
2305#[gpui::test]
2306fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2307 let buffer = cx.new(|cx| {
2308 Buffer::local(
2309 indoc! {"
2310 zero
2311 one
2312 two
2313 three
2314 four
2315 five
2316 six
2317 seven
2318 eight
2319 nine
2320 ten
2321 eleven
2322 twelve
2323 thirteen
2324 fourteen
2325 fifteen
2326 sixteen
2327 seventeen
2328 eighteen
2329 nineteen
2330 twenty
2331 twenty-one
2332 twenty-two
2333 twenty-three
2334 twenty-four
2335 "},
2336 cx,
2337 )
2338 });
2339
2340 let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2341
2342 buffer.update(cx, |buffer, cx| {
2343 let point = Point::new(12, 0);
2344 buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2345 let point = Point::new(8, 0);
2346 buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2347 });
2348
2349 let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2350
2351 let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2352
2353 assert_eq!(
2354 diff,
2355 indoc! {"
2356 @@ -6,10 +6,12 @@
2357 five
2358 six
2359 seven
2360 +FIRST INSERTION
2361 eight
2362 nine
2363 ten
2364 eleven
2365 +SECOND INSERTION
2366 twelve
2367 thirteen
2368 fourteen
2369 "}
2370 );
2371}
2372
/// Sets up logging for this test binary by calling `zlog::init_test()`
/// (presumably zlog's test-mode logger configuration).
///
/// The `#[ctor::ctor]` attribute makes this run once when the binary is loaded,
/// before the test harness starts, so no individual test has to call it.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}