zeta_tests.rs

  1use client::test::FakeServer;
  2use clock::{FakeSystemClock, ReplicaId};
  3use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
  4use cloud_llm_client::{PredictEditsBody, PredictEditsResponse};
  5use gpui::TestAppContext;
  6use http_client::FakeHttpClient;
  7use indoc::indoc;
  8use language::Point;
  9use parking_lot::Mutex;
 10use serde_json::json;
 11use settings::SettingsStore;
 12use util::{path, rel_path::rel_path};
 13
 14use crate::zeta1::MAX_EVENT_TOKENS;
 15
 16use super::*;
 17
 18const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
 19
 20#[gpui::test]
 21async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
 22    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
 23    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
 24        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
 25    });
 26
 27    let edit_preview = cx
 28        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
 29        .await;
 30
 31    let completion = EditPrediction {
 32        edits,
 33        edit_preview,
 34        buffer: buffer.clone(),
 35        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
 36        id: EditPredictionId("the-id".into()),
 37        inputs: EditPredictionInputs {
 38            events: Default::default(),
 39            included_files: Default::default(),
 40            cursor_point: cloud_llm_client::predict_edits_v3::Point {
 41                line: Line(0),
 42                column: 0,
 43            },
 44            cursor_path: Path::new("").into(),
 45        },
 46        buffer_snapshotted_at: Instant::now(),
 47        response_received_at: Instant::now(),
 48    };
 49
 50    cx.update(|cx| {
 51        assert_eq!(
 52            from_completion_edits(
 53                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
 54                &buffer,
 55                cx
 56            ),
 57            vec![(2..5, "REM".into()), (9..11, "".into())]
 58        );
 59
 60        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
 61        assert_eq!(
 62            from_completion_edits(
 63                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
 64                &buffer,
 65                cx
 66            ),
 67            vec![(2..2, "REM".into()), (6..8, "".into())]
 68        );
 69
 70        buffer.update(cx, |buffer, cx| buffer.undo(cx));
 71        assert_eq!(
 72            from_completion_edits(
 73                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
 74                &buffer,
 75                cx
 76            ),
 77            vec![(2..5, "REM".into()), (9..11, "".into())]
 78        );
 79
 80        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
 81        assert_eq!(
 82            from_completion_edits(
 83                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
 84                &buffer,
 85                cx
 86            ),
 87            vec![(3..3, "EM".into()), (7..9, "".into())]
 88        );
 89
 90        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
 91        assert_eq!(
 92            from_completion_edits(
 93                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
 94                &buffer,
 95                cx
 96            ),
 97            vec![(4..4, "M".into()), (8..10, "".into())]
 98        );
 99
100        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
101        assert_eq!(
102            from_completion_edits(
103                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
104                &buffer,
105                cx
106            ),
107            vec![(9..11, "".into())]
108        );
109
110        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
111        assert_eq!(
112            from_completion_edits(
113                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
114                &buffer,
115                cx
116            ),
117            vec![(4..4, "M".into()), (8..10, "".into())]
118        );
119
120        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
121        assert_eq!(
122            from_completion_edits(
123                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
124                &buffer,
125                cx
126            ),
127            vec![(4..4, "M".into())]
128        );
129
130        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
131        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
132    })
133}
134
135#[gpui::test]
136async fn test_clean_up_diff(cx: &mut TestAppContext) {
137    init_test(cx);
138
139    assert_eq!(
140        apply_edit_prediction(
141            indoc! {"
142                    fn main() {
143                        let word_1 = \"lorem\";
144                        let range = word.len()..word.len();
145                    }
146                "},
147            indoc! {"
148                    <|editable_region_start|>
149                    fn main() {
150                        let word_1 = \"lorem\";
151                        let range = word_1.len()..word_1.len();
152                    }
153
154                    <|editable_region_end|>
155                "},
156            cx,
157        )
158        .await,
159        indoc! {"
160                fn main() {
161                    let word_1 = \"lorem\";
162                    let range = word_1.len()..word_1.len();
163                }
164            "},
165    );
166
167    assert_eq!(
168        apply_edit_prediction(
169            indoc! {"
170                    fn main() {
171                        let story = \"the quick\"
172                    }
173                "},
174            indoc! {"
175                    <|editable_region_start|>
176                    fn main() {
177                        let story = \"the quick brown fox jumps over the lazy dog\";
178                    }
179
180                    <|editable_region_end|>
181                "},
182            cx,
183        )
184        .await,
185        indoc! {"
186                fn main() {
187                    let story = \"the quick brown fox jumps over the lazy dog\";
188                }
189            "},
190    );
191}
192
193#[gpui::test]
194async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
195    init_test(cx);
196
197    let buffer_content = "lorem\n";
198    let completion_response = indoc! {"
199            ```animals.js
200            <|start_of_file|>
201            <|editable_region_start|>
202            lorem
203            ipsum
204            <|editable_region_end|>
205            ```"};
206
207    assert_eq!(
208        apply_edit_prediction(buffer_content, completion_response, cx).await,
209        "lorem\nipsum"
210    );
211}
212
213#[gpui::test]
214async fn test_can_collect_data(cx: &mut TestAppContext) {
215    init_test(cx);
216
217    let fs = project::FakeFs::new(cx.executor());
218    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
219        .await;
220
221    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
222    let buffer = project
223        .update(cx, |project, cx| {
224            project.open_local_buffer(path!("/project/src/main.rs"), cx)
225        })
226        .await
227        .unwrap();
228
229    let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
230    zeta.update(cx, |zeta, _cx| {
231        zeta.data_collection_choice = DataCollectionChoice::Enabled
232    });
233
234    run_edit_prediction(&buffer, &project, &zeta, cx).await;
235    assert_eq!(
236        captured_request.lock().clone().unwrap().can_collect_data,
237        true
238    );
239
240    zeta.update(cx, |zeta, _cx| {
241        zeta.data_collection_choice = DataCollectionChoice::Disabled
242    });
243
244    run_edit_prediction(&buffer, &project, &zeta, cx).await;
245    assert_eq!(
246        captured_request.lock().clone().unwrap().can_collect_data,
247        false
248    );
249}
250
251#[gpui::test]
252async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
253    init_test(cx);
254
255    let fs = project::FakeFs::new(cx.executor());
256    let project = Project::test(fs.clone(), [], cx).await;
257
258    let buffer = cx.new(|_cx| {
259        Buffer::remote(
260            language::BufferId::new(1).unwrap(),
261            ReplicaId::new(1),
262            language::Capability::ReadWrite,
263            "fn main() {\n    println!(\"Hello\");\n}",
264        )
265    });
266
267    let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
268    zeta.update(cx, |zeta, _cx| {
269        zeta.data_collection_choice = DataCollectionChoice::Enabled
270    });
271
272    run_edit_prediction(&buffer, &project, &zeta, cx).await;
273    assert_eq!(
274        captured_request.lock().clone().unwrap().can_collect_data,
275        false
276    );
277}
278
279#[gpui::test]
280async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
281    init_test(cx);
282
283    let fs = project::FakeFs::new(cx.executor());
284    fs.insert_tree(
285        path!("/project"),
286        json!({
287            "LICENSE": BSD_0_TXT,
288            ".env": "SECRET_KEY=secret"
289        }),
290    )
291    .await;
292
293    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
294    let buffer = project
295        .update(cx, |project, cx| {
296            project.open_local_buffer("/project/.env", cx)
297        })
298        .await
299        .unwrap();
300
301    let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
302    zeta.update(cx, |zeta, _cx| {
303        zeta.data_collection_choice = DataCollectionChoice::Enabled
304    });
305
306    run_edit_prediction(&buffer, &project, &zeta, cx).await;
307    assert_eq!(
308        captured_request.lock().clone().unwrap().can_collect_data,
309        false
310    );
311}
312
313#[gpui::test]
314async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
315    init_test(cx);
316
317    let fs = project::FakeFs::new(cx.executor());
318    let project = Project::test(fs.clone(), [], cx).await;
319    let buffer = cx.new(|cx| Buffer::local("", cx));
320
321    let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
322    zeta.update(cx, |zeta, _cx| {
323        zeta.data_collection_choice = DataCollectionChoice::Enabled
324    });
325
326    run_edit_prediction(&buffer, &project, &zeta, cx).await;
327    assert_eq!(
328        captured_request.lock().clone().unwrap().can_collect_data,
329        false
330    );
331}
332
333#[gpui::test]
334async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
335    init_test(cx);
336
337    let fs = project::FakeFs::new(cx.executor());
338    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
339        .await;
340
341    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
342    let buffer = project
343        .update(cx, |project, cx| {
344            project.open_local_buffer("/project/main.rs", cx)
345        })
346        .await
347        .unwrap();
348
349    let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
350    zeta.update(cx, |zeta, _cx| {
351        zeta.data_collection_choice = DataCollectionChoice::Enabled
352    });
353
354    run_edit_prediction(&buffer, &project, &zeta, cx).await;
355    assert_eq!(
356        captured_request.lock().clone().unwrap().can_collect_data,
357        false
358    );
359}
360
361#[gpui::test]
362async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
363    init_test(cx);
364
365    let fs = project::FakeFs::new(cx.executor());
366    fs.insert_tree(
367        path!("/open_source_worktree"),
368        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
369    )
370    .await;
371    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
372        .await;
373
374    let project = Project::test(
375        fs.clone(),
376        [
377            path!("/open_source_worktree").as_ref(),
378            path!("/closed_source_worktree").as_ref(),
379        ],
380        cx,
381    )
382    .await;
383    let buffer = project
384        .update(cx, |project, cx| {
385            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
386        })
387        .await
388        .unwrap();
389
390    let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
391    zeta.update(cx, |zeta, _cx| {
392        zeta.data_collection_choice = DataCollectionChoice::Enabled
393    });
394
395    run_edit_prediction(&buffer, &project, &zeta, cx).await;
396    assert_eq!(
397        captured_request.lock().clone().unwrap().can_collect_data,
398        true
399    );
400
401    let closed_source_file = project
402        .update(cx, |project, cx| {
403            let worktree2 = project
404                .worktree_for_root_name("closed_source_worktree", cx)
405                .unwrap();
406            worktree2.update(cx, |worktree2, cx| {
407                worktree2.load_file(rel_path("main.rs"), cx)
408            })
409        })
410        .await
411        .unwrap()
412        .file;
413
414    buffer.update(cx, |buffer, cx| {
415        buffer.file_updated(closed_source_file, cx);
416    });
417
418    run_edit_prediction(&buffer, &project, &zeta, cx).await;
419    assert_eq!(
420        captured_request.lock().clone().unwrap().can_collect_data,
421        false
422    );
423}
424
425#[gpui::test]
426async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
427    init_test(cx);
428
429    let fs = project::FakeFs::new(cx.executor());
430    fs.insert_tree(
431        path!("/worktree1"),
432        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
433    )
434    .await;
435    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
436        .await;
437
438    let project = Project::test(
439        fs.clone(),
440        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
441        cx,
442    )
443    .await;
444    let buffer = project
445        .update(cx, |project, cx| {
446            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
447        })
448        .await
449        .unwrap();
450    let private_buffer = project
451        .update(cx, |project, cx| {
452            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
453        })
454        .await
455        .unwrap();
456
457    let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
458    zeta.update(cx, |zeta, _cx| {
459        zeta.data_collection_choice = DataCollectionChoice::Enabled
460    });
461
462    run_edit_prediction(&buffer, &project, &zeta, cx).await;
463    assert_eq!(
464        captured_request.lock().clone().unwrap().can_collect_data,
465        true
466    );
467
468    // this has a side effect of registering the buffer to watch for edits
469    run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
470    assert_eq!(
471        captured_request.lock().clone().unwrap().can_collect_data,
472        false
473    );
474
475    private_buffer.update(cx, |private_buffer, cx| {
476        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
477    });
478
479    run_edit_prediction(&buffer, &project, &zeta, cx).await;
480    assert_eq!(
481        captured_request.lock().clone().unwrap().can_collect_data,
482        false
483    );
484
485    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
486    // included
487    buffer.update(cx, |buffer, cx| {
488        buffer.edit(
489            [(
490                0..0,
491                " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
492            )],
493            None,
494            cx,
495        );
496    });
497
498    run_edit_prediction(&buffer, &project, &zeta, cx).await;
499    assert_eq!(
500        captured_request.lock().clone().unwrap().can_collect_data,
501        true
502    );
503}
504
505fn init_test(cx: &mut TestAppContext) {
506    cx.update(|cx| {
507        let settings_store = SettingsStore::test(cx);
508        cx.set_global(settings_store);
509    });
510}
511
512async fn apply_edit_prediction(
513    buffer_content: &str,
514    completion_response: &str,
515    cx: &mut TestAppContext,
516) -> String {
517    let fs = project::FakeFs::new(cx.executor());
518    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
519    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
520    let (zeta, _, response) = make_test_zeta(&project, cx).await;
521    *response.lock() = completion_response.to_string();
522    let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
523    buffer.update(cx, |buffer, cx| {
524        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
525    });
526    buffer.read_with(cx, |buffer, _| buffer.text())
527}
528
529async fn run_edit_prediction(
530    buffer: &Entity<Buffer>,
531    project: &Entity<Project>,
532    zeta: &Entity<Zeta>,
533    cx: &mut TestAppContext,
534) -> EditPrediction {
535    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
536    zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
537    cx.background_executor.run_until_parked();
538    let prediction_task = zeta.update(cx, |zeta, cx| {
539        zeta.request_prediction(&project, buffer, cursor, Default::default(), cx)
540    });
541    prediction_task.await.unwrap().unwrap().prediction.unwrap()
542}
543
544async fn make_test_zeta(
545    project: &Entity<Project>,
546    cx: &mut TestAppContext,
547) -> (
548    Entity<Zeta>,
549    Arc<Mutex<Option<PredictEditsBody>>>,
550    Arc<Mutex<String>>,
551) {
552    let default_response = indoc! {"
553            ```main.rs
554            <|start_of_file|>
555            <|editable_region_start|>
556            hello world
557            <|editable_region_end|>
558            ```"
559    };
560    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
561    let completion_response: Arc<Mutex<String>> =
562        Arc::new(Mutex::new(default_response.to_string()));
563    let http_client = FakeHttpClient::create({
564        let captured_request = captured_request.clone();
565        let completion_response = completion_response.clone();
566        let mut next_request_id = 0;
567        move |req| {
568            let captured_request = captured_request.clone();
569            let completion_response = completion_response.clone();
570            async move {
571                match (req.method(), req.uri().path()) {
572                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
573                        .status(200)
574                        .body(
575                            serde_json::to_string(&CreateLlmTokenResponse {
576                                token: LlmToken("the-llm-token".to_string()),
577                            })
578                            .unwrap()
579                            .into(),
580                        )
581                        .unwrap()),
582                    (&Method::POST, "/predict_edits/v2") => {
583                        let mut request_body = String::new();
584                        req.into_body().read_to_string(&mut request_body).await?;
585                        *captured_request.lock() =
586                            Some(serde_json::from_str(&request_body).unwrap());
587                        next_request_id += 1;
588                        Ok(http_client::Response::builder()
589                            .status(200)
590                            .body(
591                                serde_json::to_string(&PredictEditsResponse {
592                                    request_id: format!("request-{next_request_id}"),
593                                    output_excerpt: completion_response.lock().clone(),
594                                })
595                                .unwrap()
596                                .into(),
597                            )
598                            .unwrap())
599                    }
600                    _ => Ok(http_client::Response::builder()
601                        .status(404)
602                        .body("Not Found".into())
603                        .unwrap()),
604                }
605            }
606        }
607    });
608
609    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
610    cx.update(|cx| {
611        RefreshLlmTokenListener::register(client.clone(), cx);
612    });
613    let _server = FakeServer::for_client(42, &client, cx).await;
614
615    let zeta = cx.new(|cx| {
616        let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
617        zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1);
618
619        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
620        for worktree in worktrees {
621            let worktree_id = worktree.read(cx).id();
622            zeta.get_or_init_zeta_project(project, cx)
623                .license_detection_watchers
624                .entry(worktree_id)
625                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
626        }
627
628        zeta
629    });
630
631    (zeta, captured_request, completion_response)
632}
633
634fn to_completion_edits(
635    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
636    buffer: &Entity<Buffer>,
637    cx: &App,
638) -> Vec<(Range<Anchor>, Arc<str>)> {
639    let buffer = buffer.read(cx);
640    iterator
641        .into_iter()
642        .map(|(range, text)| {
643            (
644                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
645                text,
646            )
647        })
648        .collect()
649}
650
651fn from_completion_edits(
652    editor_edits: &[(Range<Anchor>, Arc<str>)],
653    buffer: &Entity<Buffer>,
654    cx: &App,
655) -> Vec<(Range<usize>, Arc<str>)> {
656    let buffer = buffer.read(cx);
657    editor_edits
658        .iter()
659        .map(|(range, text)| {
660            (
661                range.start.to_offset(buffer)..range.end.to_offset(buffer),
662                text.clone(),
663            )
664        })
665        .collect()
666}
667
/// Initializes test logging for the whole test binary. `#[ctor::ctor]` runs
/// this when the binary is loaded, before any test executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}