edit_prediction_tests.rs

   1use super::*;
   2use crate::zeta1::MAX_EVENT_TOKENS;
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9};
  10use edit_prediction_context::Line;
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Point, ToOffset as _};
  21use lsp::LanguageServerId;
  22use open_ai::Usage;
  23use parking_lot::Mutex;
  24use pretty_assertions::{assert_eq, assert_matches};
  25use project::{FakeFs, Project};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, sync::Arc, time::Duration};
  29use util::{path, rel_path::rel_path};
  30use uuid::Uuid;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
  34#[gpui::test]
  35async fn test_current_state(cx: &mut TestAppContext) {
  36    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  37    let fs = FakeFs::new(cx.executor());
  38    fs.insert_tree(
  39        "/root",
  40        json!({
  41            "1.txt": "Hello!\nHow\nBye\n",
  42            "2.txt": "Hola!\nComo\nAdios\n"
  43        }),
  44    )
  45    .await;
  46    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  47
  48    ep_store.update(cx, |ep_store, cx| {
  49        ep_store.register_project(&project, cx);
  50    });
  51
  52    let buffer1 = project
  53        .update(cx, |project, cx| {
  54            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  55            project.set_active_path(Some(path.clone()), cx);
  56            project.open_buffer(path, cx)
  57        })
  58        .await
  59        .unwrap();
  60    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  61    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  62
  63    // Prediction for current file
  64
  65    ep_store.update(cx, |ep_store, cx| {
  66        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  67    });
  68    let (_request, respond_tx) = requests.predict.next().await.unwrap();
  69
  70    respond_tx
  71        .send(model_response(indoc! {r"
  72            --- a/root/1.txt
  73            +++ b/root/1.txt
  74            @@ ... @@
  75             Hello!
  76            -How
  77            +How are you?
  78             Bye
  79        "}))
  80        .unwrap();
  81
  82    cx.run_until_parked();
  83
  84    ep_store.read_with(cx, |ep_store, cx| {
  85        let prediction = ep_store
  86            .current_prediction_for_buffer(&buffer1, &project, cx)
  87            .unwrap();
  88        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  89    });
  90
  91    ep_store.update(cx, |ep_store, _cx| {
  92        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
  93    });
  94
  95    // Prediction for diagnostic in another file
  96
  97    let diagnostic = lsp::Diagnostic {
  98        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
  99        severity: Some(lsp::DiagnosticSeverity::ERROR),
 100        message: "Sentence is incomplete".to_string(),
 101        ..Default::default()
 102    };
 103
 104    project.update(cx, |project, cx| {
 105        project.lsp_store().update(cx, |lsp_store, cx| {
 106            lsp_store
 107                .update_diagnostics(
 108                    LanguageServerId(0),
 109                    lsp::PublishDiagnosticsParams {
 110                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 111                        diagnostics: vec![diagnostic],
 112                        version: None,
 113                    },
 114                    None,
 115                    language::DiagnosticSourceKind::Pushed,
 116                    &[],
 117                    cx,
 118                )
 119                .unwrap();
 120        });
 121    });
 122
 123    let (_request, respond_tx) = requests.predict.next().await.unwrap();
 124    respond_tx
 125        .send(model_response(indoc! {r#"
 126            --- a/root/2.txt
 127            +++ b/root/2.txt
 128             Hola!
 129            -Como
 130            +Como estas?
 131             Adios
 132        "#}))
 133        .unwrap();
 134    cx.run_until_parked();
 135
 136    ep_store.read_with(cx, |ep_store, cx| {
 137        let prediction = ep_store
 138            .current_prediction_for_buffer(&buffer1, &project, cx)
 139            .unwrap();
 140        assert_matches!(
 141            prediction,
 142            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 143        );
 144    });
 145
 146    let buffer2 = project
 147        .update(cx, |project, cx| {
 148            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 149            project.open_buffer(path, cx)
 150        })
 151        .await
 152        .unwrap();
 153
 154    ep_store.read_with(cx, |ep_store, cx| {
 155        let prediction = ep_store
 156            .current_prediction_for_buffer(&buffer2, &project, cx)
 157            .unwrap();
 158        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 159    });
 160}
 161
 162#[gpui::test]
 163async fn test_simple_request(cx: &mut TestAppContext) {
 164    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 165    let fs = FakeFs::new(cx.executor());
 166    fs.insert_tree(
 167        "/root",
 168        json!({
 169            "foo.md":  "Hello!\nHow\nBye\n"
 170        }),
 171    )
 172    .await;
 173    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 174
 175    let buffer = project
 176        .update(cx, |project, cx| {
 177            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 178            project.open_buffer(path, cx)
 179        })
 180        .await
 181        .unwrap();
 182    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 183    let position = snapshot.anchor_before(language::Point::new(1, 3));
 184
 185    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 186        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 187    });
 188
 189    let (_, respond_tx) = requests.predict.next().await.unwrap();
 190
 191    // TODO Put back when we have a structured request again
 192    // assert_eq!(
 193    //     request.excerpt_path.as_ref(),
 194    //     Path::new(path!("root/foo.md"))
 195    // );
 196    // assert_eq!(
 197    //     request.cursor_point,
 198    //     Point {
 199    //         line: Line(1),
 200    //         column: 3
 201    //     }
 202    // );
 203
 204    respond_tx
 205        .send(model_response(indoc! { r"
 206            --- a/root/foo.md
 207            +++ b/root/foo.md
 208            @@ ... @@
 209             Hello!
 210            -How
 211            +How are you?
 212             Bye
 213        "}))
 214        .unwrap();
 215
 216    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 217
 218    assert_eq!(prediction.edits.len(), 1);
 219    assert_eq!(
 220        prediction.edits[0].0.to_point(&snapshot).start,
 221        language::Point::new(1, 3)
 222    );
 223    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 224}
 225
 226#[gpui::test]
 227async fn test_request_events(cx: &mut TestAppContext) {
 228    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 229    let fs = FakeFs::new(cx.executor());
 230    fs.insert_tree(
 231        "/root",
 232        json!({
 233            "foo.md": "Hello!\n\nBye\n"
 234        }),
 235    )
 236    .await;
 237    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 238
 239    let buffer = project
 240        .update(cx, |project, cx| {
 241            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 242            project.open_buffer(path, cx)
 243        })
 244        .await
 245        .unwrap();
 246
 247    ep_store.update(cx, |ep_store, cx| {
 248        ep_store.register_buffer(&buffer, &project, cx);
 249    });
 250
 251    buffer.update(cx, |buffer, cx| {
 252        buffer.edit(vec![(7..7, "How")], None, cx);
 253    });
 254
 255    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 256    let position = snapshot.anchor_before(language::Point::new(1, 3));
 257
 258    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 259        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 260    });
 261
 262    let (request, respond_tx) = requests.predict.next().await.unwrap();
 263
 264    let prompt = prompt_from_request(&request);
 265    assert!(
 266        prompt.contains(indoc! {"
 267        --- a/root/foo.md
 268        +++ b/root/foo.md
 269        @@ -1,3 +1,3 @@
 270         Hello!
 271        -
 272        +How
 273         Bye
 274    "}),
 275        "{prompt}"
 276    );
 277
 278    respond_tx
 279        .send(model_response(indoc! {r#"
 280            --- a/root/foo.md
 281            +++ b/root/foo.md
 282            @@ ... @@
 283             Hello!
 284            -How
 285            +How are you?
 286             Bye
 287        "#}))
 288        .unwrap();
 289
 290    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 291
 292    assert_eq!(prediction.edits.len(), 1);
 293    assert_eq!(
 294        prediction.edits[0].0.to_point(&snapshot).start,
 295        language::Point::new(1, 3)
 296    );
 297    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 298}
 299
 300#[gpui::test]
 301async fn test_empty_prediction(cx: &mut TestAppContext) {
 302    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 303    let fs = FakeFs::new(cx.executor());
 304    fs.insert_tree(
 305        "/root",
 306        json!({
 307            "foo.md":  "Hello!\nHow\nBye\n"
 308        }),
 309    )
 310    .await;
 311    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 312
 313    let buffer = project
 314        .update(cx, |project, cx| {
 315            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 316            project.open_buffer(path, cx)
 317        })
 318        .await
 319        .unwrap();
 320    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 321    let position = snapshot.anchor_before(language::Point::new(1, 3));
 322
 323    ep_store.update(cx, |ep_store, cx| {
 324        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 325    });
 326
 327    const NO_OP_DIFF: &str = indoc! { r"
 328        --- a/root/foo.md
 329        +++ b/root/foo.md
 330        @@ ... @@
 331         Hello!
 332        -How
 333        +How
 334         Bye
 335    "};
 336
 337    let (_, respond_tx) = requests.predict.next().await.unwrap();
 338    let response = model_response(NO_OP_DIFF);
 339    let id = response.id.clone();
 340    respond_tx.send(response).unwrap();
 341
 342    cx.run_until_parked();
 343
 344    ep_store.read_with(cx, |ep_store, cx| {
 345        assert!(
 346            ep_store
 347                .current_prediction_for_buffer(&buffer, &project, cx)
 348                .is_none()
 349        );
 350    });
 351
 352    // prediction is reported as rejected
 353    let (reject_request, _) = requests.reject.next().await.unwrap();
 354
 355    assert_eq!(
 356        &reject_request.rejections,
 357        &[EditPredictionRejection {
 358            request_id: id,
 359            reason: EditPredictionRejectReason::Empty,
 360            was_shown: false
 361        }]
 362    );
 363}
 364
 365#[gpui::test]
 366async fn test_interpolated_empty(cx: &mut TestAppContext) {
 367    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 368    let fs = FakeFs::new(cx.executor());
 369    fs.insert_tree(
 370        "/root",
 371        json!({
 372            "foo.md":  "Hello!\nHow\nBye\n"
 373        }),
 374    )
 375    .await;
 376    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 377
 378    let buffer = project
 379        .update(cx, |project, cx| {
 380            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 381            project.open_buffer(path, cx)
 382        })
 383        .await
 384        .unwrap();
 385    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 386    let position = snapshot.anchor_before(language::Point::new(1, 3));
 387
 388    ep_store.update(cx, |ep_store, cx| {
 389        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 390    });
 391
 392    let (_, respond_tx) = requests.predict.next().await.unwrap();
 393
 394    buffer.update(cx, |buffer, cx| {
 395        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 396    });
 397
 398    let response = model_response(SIMPLE_DIFF);
 399    let id = response.id.clone();
 400    respond_tx.send(response).unwrap();
 401
 402    cx.run_until_parked();
 403
 404    ep_store.read_with(cx, |ep_store, cx| {
 405        assert!(
 406            ep_store
 407                .current_prediction_for_buffer(&buffer, &project, cx)
 408                .is_none()
 409        );
 410    });
 411
 412    // prediction is reported as rejected
 413    let (reject_request, _) = requests.reject.next().await.unwrap();
 414
 415    assert_eq!(
 416        &reject_request.rejections,
 417        &[EditPredictionRejection {
 418            request_id: id,
 419            reason: EditPredictionRejectReason::InterpolatedEmpty,
 420            was_shown: false
 421        }]
 422    );
 423}
 424
/// Model-response fixture shared by several tests in this file. It is a diff
/// against `root/foo.md` that rewrites the line `How` to `How are you?` and
/// keeps `Hello!` and `Bye` as context.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 434
 435#[gpui::test]
 436async fn test_replace_current(cx: &mut TestAppContext) {
 437    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 438    let fs = FakeFs::new(cx.executor());
 439    fs.insert_tree(
 440        "/root",
 441        json!({
 442            "foo.md":  "Hello!\nHow\nBye\n"
 443        }),
 444    )
 445    .await;
 446    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 447
 448    let buffer = project
 449        .update(cx, |project, cx| {
 450            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 451            project.open_buffer(path, cx)
 452        })
 453        .await
 454        .unwrap();
 455    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 456    let position = snapshot.anchor_before(language::Point::new(1, 3));
 457
 458    ep_store.update(cx, |ep_store, cx| {
 459        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 460    });
 461
 462    let (_, respond_tx) = requests.predict.next().await.unwrap();
 463    let first_response = model_response(SIMPLE_DIFF);
 464    let first_id = first_response.id.clone();
 465    respond_tx.send(first_response).unwrap();
 466
 467    cx.run_until_parked();
 468
 469    ep_store.read_with(cx, |ep_store, cx| {
 470        assert_eq!(
 471            ep_store
 472                .current_prediction_for_buffer(&buffer, &project, cx)
 473                .unwrap()
 474                .id
 475                .0,
 476            first_id
 477        );
 478    });
 479
 480    // a second request is triggered
 481    ep_store.update(cx, |ep_store, cx| {
 482        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 483    });
 484
 485    let (_, respond_tx) = requests.predict.next().await.unwrap();
 486    let second_response = model_response(SIMPLE_DIFF);
 487    let second_id = second_response.id.clone();
 488    respond_tx.send(second_response).unwrap();
 489
 490    cx.run_until_parked();
 491
 492    ep_store.read_with(cx, |ep_store, cx| {
 493        // second replaces first
 494        assert_eq!(
 495            ep_store
 496                .current_prediction_for_buffer(&buffer, &project, cx)
 497                .unwrap()
 498                .id
 499                .0,
 500            second_id
 501        );
 502    });
 503
 504    // first is reported as replaced
 505    let (reject_request, _) = requests.reject.next().await.unwrap();
 506
 507    assert_eq!(
 508        &reject_request.rejections,
 509        &[EditPredictionRejection {
 510            request_id: first_id,
 511            reason: EditPredictionRejectReason::Replaced,
 512            was_shown: false
 513        }]
 514    );
 515}
 516
 517#[gpui::test]
 518async fn test_current_preferred(cx: &mut TestAppContext) {
 519    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 520    let fs = FakeFs::new(cx.executor());
 521    fs.insert_tree(
 522        "/root",
 523        json!({
 524            "foo.md":  "Hello!\nHow\nBye\n"
 525        }),
 526    )
 527    .await;
 528    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 529
 530    let buffer = project
 531        .update(cx, |project, cx| {
 532            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 533            project.open_buffer(path, cx)
 534        })
 535        .await
 536        .unwrap();
 537    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 538    let position = snapshot.anchor_before(language::Point::new(1, 3));
 539
 540    ep_store.update(cx, |ep_store, cx| {
 541        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 542    });
 543
 544    let (_, respond_tx) = requests.predict.next().await.unwrap();
 545    let first_response = model_response(SIMPLE_DIFF);
 546    let first_id = first_response.id.clone();
 547    respond_tx.send(first_response).unwrap();
 548
 549    cx.run_until_parked();
 550
 551    ep_store.read_with(cx, |ep_store, cx| {
 552        assert_eq!(
 553            ep_store
 554                .current_prediction_for_buffer(&buffer, &project, cx)
 555                .unwrap()
 556                .id
 557                .0,
 558            first_id
 559        );
 560    });
 561
 562    // a second request is triggered
 563    ep_store.update(cx, |ep_store, cx| {
 564        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 565    });
 566
 567    let (_, respond_tx) = requests.predict.next().await.unwrap();
 568    // worse than current prediction
 569    let second_response = model_response(indoc! { r"
 570        --- a/root/foo.md
 571        +++ b/root/foo.md
 572        @@ ... @@
 573         Hello!
 574        -How
 575        +How are
 576         Bye
 577    "});
 578    let second_id = second_response.id.clone();
 579    respond_tx.send(second_response).unwrap();
 580
 581    cx.run_until_parked();
 582
 583    ep_store.read_with(cx, |ep_store, cx| {
 584        // first is preferred over second
 585        assert_eq!(
 586            ep_store
 587                .current_prediction_for_buffer(&buffer, &project, cx)
 588                .unwrap()
 589                .id
 590                .0,
 591            first_id
 592        );
 593    });
 594
 595    // second is reported as rejected
 596    let (reject_request, _) = requests.reject.next().await.unwrap();
 597
 598    assert_eq!(
 599        &reject_request.rejections,
 600        &[EditPredictionRejection {
 601            request_id: second_id,
 602            reason: EditPredictionRejectReason::CurrentPreferred,
 603            was_shown: false
 604        }]
 605    );
 606}
 607
 608#[gpui::test]
 609async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 610    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 611    let fs = FakeFs::new(cx.executor());
 612    fs.insert_tree(
 613        "/root",
 614        json!({
 615            "foo.md":  "Hello!\nHow\nBye\n"
 616        }),
 617    )
 618    .await;
 619    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 620
 621    let buffer = project
 622        .update(cx, |project, cx| {
 623            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 624            project.open_buffer(path, cx)
 625        })
 626        .await
 627        .unwrap();
 628    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 629    let position = snapshot.anchor_before(language::Point::new(1, 3));
 630
 631    // start two refresh tasks
 632    ep_store.update(cx, |ep_store, cx| {
 633        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 634    });
 635
 636    let (_, respond_first) = requests.predict.next().await.unwrap();
 637
 638    ep_store.update(cx, |ep_store, cx| {
 639        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 640    });
 641
 642    let (_, respond_second) = requests.predict.next().await.unwrap();
 643
 644    // wait for throttle
 645    cx.run_until_parked();
 646
 647    // second responds first
 648    let second_response = model_response(SIMPLE_DIFF);
 649    let second_id = second_response.id.clone();
 650    respond_second.send(second_response).unwrap();
 651
 652    cx.run_until_parked();
 653
 654    ep_store.read_with(cx, |ep_store, cx| {
 655        // current prediction is second
 656        assert_eq!(
 657            ep_store
 658                .current_prediction_for_buffer(&buffer, &project, cx)
 659                .unwrap()
 660                .id
 661                .0,
 662            second_id
 663        );
 664    });
 665
 666    let first_response = model_response(SIMPLE_DIFF);
 667    let first_id = first_response.id.clone();
 668    respond_first.send(first_response).unwrap();
 669
 670    cx.run_until_parked();
 671
 672    ep_store.read_with(cx, |ep_store, cx| {
 673        // current prediction is still second, since first was cancelled
 674        assert_eq!(
 675            ep_store
 676                .current_prediction_for_buffer(&buffer, &project, cx)
 677                .unwrap()
 678                .id
 679                .0,
 680            second_id
 681        );
 682    });
 683
 684    // first is reported as rejected
 685    let (reject_request, _) = requests.reject.next().await.unwrap();
 686
 687    cx.run_until_parked();
 688
 689    assert_eq!(
 690        &reject_request.rejections,
 691        &[EditPredictionRejection {
 692            request_id: first_id,
 693            reason: EditPredictionRejectReason::Canceled,
 694            was_shown: false
 695        }]
 696    );
 697}
 698
 699#[gpui::test]
 700async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 701    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 702    let fs = FakeFs::new(cx.executor());
 703    fs.insert_tree(
 704        "/root",
 705        json!({
 706            "foo.md":  "Hello!\nHow\nBye\n"
 707        }),
 708    )
 709    .await;
 710    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 711
 712    let buffer = project
 713        .update(cx, |project, cx| {
 714            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 715            project.open_buffer(path, cx)
 716        })
 717        .await
 718        .unwrap();
 719    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 720    let position = snapshot.anchor_before(language::Point::new(1, 3));
 721
 722    // start two refresh tasks
 723    ep_store.update(cx, |ep_store, cx| {
 724        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 725    });
 726
 727    let (_, respond_first) = requests.predict.next().await.unwrap();
 728
 729    ep_store.update(cx, |ep_store, cx| {
 730        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 731    });
 732
 733    let (_, respond_second) = requests.predict.next().await.unwrap();
 734
 735    // wait for throttle, so requests are sent
 736    cx.run_until_parked();
 737
 738    ep_store.update(cx, |ep_store, cx| {
 739        // start a third request
 740        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 741
 742        // 2 are pending, so 2nd is cancelled
 743        assert_eq!(
 744            ep_store
 745                .get_or_init_project(&project, cx)
 746                .cancelled_predictions
 747                .iter()
 748                .copied()
 749                .collect::<Vec<_>>(),
 750            [1]
 751        );
 752    });
 753
 754    // wait for throttle
 755    cx.run_until_parked();
 756
 757    let (_, respond_third) = requests.predict.next().await.unwrap();
 758
 759    let first_response = model_response(SIMPLE_DIFF);
 760    let first_id = first_response.id.clone();
 761    respond_first.send(first_response).unwrap();
 762
 763    cx.run_until_parked();
 764
 765    ep_store.read_with(cx, |ep_store, cx| {
 766        // current prediction is first
 767        assert_eq!(
 768            ep_store
 769                .current_prediction_for_buffer(&buffer, &project, cx)
 770                .unwrap()
 771                .id
 772                .0,
 773            first_id
 774        );
 775    });
 776
 777    let cancelled_response = model_response(SIMPLE_DIFF);
 778    let cancelled_id = cancelled_response.id.clone();
 779    respond_second.send(cancelled_response).unwrap();
 780
 781    cx.run_until_parked();
 782
 783    ep_store.read_with(cx, |ep_store, cx| {
 784        // current prediction is still first, since second was cancelled
 785        assert_eq!(
 786            ep_store
 787                .current_prediction_for_buffer(&buffer, &project, cx)
 788                .unwrap()
 789                .id
 790                .0,
 791            first_id
 792        );
 793    });
 794
 795    let third_response = model_response(SIMPLE_DIFF);
 796    let third_response_id = third_response.id.clone();
 797    respond_third.send(third_response).unwrap();
 798
 799    cx.run_until_parked();
 800
 801    ep_store.read_with(cx, |ep_store, cx| {
 802        // third completes and replaces first
 803        assert_eq!(
 804            ep_store
 805                .current_prediction_for_buffer(&buffer, &project, cx)
 806                .unwrap()
 807                .id
 808                .0,
 809            third_response_id
 810        );
 811    });
 812
 813    // second is reported as rejected
 814    let (reject_request, _) = requests.reject.next().await.unwrap();
 815
 816    cx.run_until_parked();
 817
 818    assert_eq!(
 819        &reject_request.rejections,
 820        &[
 821            EditPredictionRejection {
 822                request_id: cancelled_id,
 823                reason: EditPredictionRejectReason::Canceled,
 824                was_shown: false
 825            },
 826            EditPredictionRejection {
 827                request_id: first_id,
 828                reason: EditPredictionRejectReason::Replaced,
 829                was_shown: false
 830            }
 831        ]
 832    );
 833}
 834
 835#[gpui::test]
 836async fn test_rejections_flushing(cx: &mut TestAppContext) {
 837    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 838
 839    ep_store.update(cx, |ep_store, _cx| {
 840        ep_store.reject_prediction(
 841            EditPredictionId("test-1".into()),
 842            EditPredictionRejectReason::Discarded,
 843            false,
 844        );
 845        ep_store.reject_prediction(
 846            EditPredictionId("test-2".into()),
 847            EditPredictionRejectReason::Canceled,
 848            true,
 849        );
 850    });
 851
 852    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
 853    cx.run_until_parked();
 854
 855    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
 856    respond_tx.send(()).unwrap();
 857
 858    // batched
 859    assert_eq!(reject_request.rejections.len(), 2);
 860    assert_eq!(
 861        reject_request.rejections[0],
 862        EditPredictionRejection {
 863            request_id: "test-1".to_string(),
 864            reason: EditPredictionRejectReason::Discarded,
 865            was_shown: false
 866        }
 867    );
 868    assert_eq!(
 869        reject_request.rejections[1],
 870        EditPredictionRejection {
 871            request_id: "test-2".to_string(),
 872            reason: EditPredictionRejectReason::Canceled,
 873            was_shown: true
 874        }
 875    );
 876
 877    // Reaching batch size limit sends without debounce
 878    ep_store.update(cx, |ep_store, _cx| {
 879        for i in 0..70 {
 880            ep_store.reject_prediction(
 881                EditPredictionId(format!("batch-{}", i).into()),
 882                EditPredictionRejectReason::Discarded,
 883                false,
 884            );
 885        }
 886    });
 887
 888    // First MAX/2 items are sent immediately
 889    cx.run_until_parked();
 890    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
 891    respond_tx.send(()).unwrap();
 892
 893    assert_eq!(reject_request.rejections.len(), 50);
 894    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
 895    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
 896
 897    // Remaining items are debounced with the next batch
 898    cx.executor().advance_clock(Duration::from_secs(15));
 899    cx.run_until_parked();
 900
 901    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
 902    respond_tx.send(()).unwrap();
 903
 904    assert_eq!(reject_request.rejections.len(), 20);
 905    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
 906    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
 907
 908    // Request failure
 909    ep_store.update(cx, |ep_store, _cx| {
 910        ep_store.reject_prediction(
 911            EditPredictionId("retry-1".into()),
 912            EditPredictionRejectReason::Discarded,
 913            false,
 914        );
 915    });
 916
 917    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
 918    cx.run_until_parked();
 919
 920    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
 921    assert_eq!(reject_request.rejections.len(), 1);
 922    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
 923    // Simulate failure
 924    drop(_respond_tx);
 925
 926    // Add another rejection
 927    ep_store.update(cx, |ep_store, _cx| {
 928        ep_store.reject_prediction(
 929            EditPredictionId("retry-2".into()),
 930            EditPredictionRejectReason::Discarded,
 931            false,
 932        );
 933    });
 934
 935    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
 936    cx.run_until_parked();
 937
 938    // Retry should include both the failed item and the new one
 939    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
 940    respond_tx.send(()).unwrap();
 941
 942    assert_eq!(reject_request.rejections.len(), 2);
 943    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
 944    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
 945}
 946
 947// Skipped until we start including diagnostics in prompt
 948// #[gpui::test]
 949// async fn test_request_diagnostics(cx: &mut TestAppContext) {
 950//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
 951//     let fs = FakeFs::new(cx.executor());
 952//     fs.insert_tree(
 953//         "/root",
 954//         json!({
 955//             "foo.md": "Hello!\nBye"
 956//         }),
 957//     )
 958//     .await;
 959//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 960
 961//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
 962//     let diagnostic = lsp::Diagnostic {
 963//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 964//         severity: Some(lsp::DiagnosticSeverity::ERROR),
 965//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
 966//         ..Default::default()
 967//     };
 968
 969//     project.update(cx, |project, cx| {
 970//         project.lsp_store().update(cx, |lsp_store, cx| {
 971//             // Create some diagnostics
 972//             lsp_store
 973//                 .update_diagnostics(
 974//                     LanguageServerId(0),
 975//                     lsp::PublishDiagnosticsParams {
 976//                         uri: path_to_buffer_uri.clone(),
 977//                         diagnostics: vec![diagnostic],
 978//                         version: None,
 979//                     },
 980//                     None,
 981//                     language::DiagnosticSourceKind::Pushed,
 982//                     &[],
 983//                     cx,
 984//                 )
 985//                 .unwrap();
 986//         });
 987//     });
 988
 989//     let buffer = project
 990//         .update(cx, |project, cx| {
 991//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 992//             project.open_buffer(path, cx)
 993//         })
 994//         .await
 995//         .unwrap();
 996
 997//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 998//     let position = snapshot.anchor_before(language::Point::new(0, 0));
 999
1000//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1001//         ep_store.request_prediction(&project, &buffer, position, cx)
1002//     });
1003
1004//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1005
1006//     assert_eq!(request.diagnostic_groups.len(), 1);
1007//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1008//         .unwrap();
1009//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1010//     assert_eq!(
1011//         value,
1012//         json!({
1013//             "entries": [{
1014//                 "range": {
1015//                     "start": 8,
1016//                     "end": 10
1017//                 },
1018//                 "diagnostic": {
1019//                     "source": null,
1020//                     "code": null,
1021//                     "code_description": null,
1022//                     "severity": 1,
1023//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1024//                     "markdown": null,
1025//                     "group_id": 0,
1026//                     "is_primary": true,
1027//                     "is_disk_based": false,
1028//                     "is_unnecessary": false,
1029//                     "source_kind": "Pushed",
1030//                     "data": null,
1031//                     "underline": true
1032//                 }
1033//             }],
1034//             "primary_ix": 0
1035//         })
1036//     );
1037// }
1038
1039fn model_response(text: &str) -> open_ai::Response {
1040    open_ai::Response {
1041        id: Uuid::new_v4().to_string(),
1042        object: "response".into(),
1043        created: 0,
1044        model: "model".into(),
1045        choices: vec![open_ai::Choice {
1046            index: 0,
1047            message: open_ai::RequestMessage::Assistant {
1048                content: Some(open_ai::MessageContent::Plain(text.to_string())),
1049                tool_calls: vec![],
1050            },
1051            finish_reason: None,
1052        }],
1053        usage: Usage {
1054            prompt_tokens: 0,
1055            completion_tokens: 0,
1056            total_tokens: 0,
1057        },
1058    }
1059}
1060
1061fn prompt_from_request(request: &open_ai::Request) -> &str {
1062    assert_eq!(request.messages.len(), 1);
1063    let open_ai::RequestMessage::User {
1064        content: open_ai::MessageContent::Plain(content),
1065        ..
1066    } = &request.messages[0]
1067    else {
1068        panic!(
1069            "Request does not have single user message of type Plain. {:#?}",
1070            request
1071        );
1072    };
1073    content
1074}
1075
/// Receiving ends for requests intercepted by the fake HTTP client built in
/// `init_test_with_fake_client`.
///
/// Each item pairs the deserialized request body with a one-shot sender. The
/// test replies through the sender, or drops it to make the request fail.
struct RequestChannels {
    /// Requests sent to `/predict_edits/raw`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    /// Requests sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1080
/// Installs test globals (settings store, logging, language model) and obtains
/// the store via `EditPredictionStore::global`, backed by a fake HTTP client.
///
/// The fake client answers `/client/llm_tokens` directly. Requests to
/// `/predict_edits/raw` and `/predict_edits/reject` are deserialized and sent
/// over the returned [`RequestChannels`]. Their HTTP response is withheld until
/// the test replies through the paired one-shot sender. Any other path panics.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            // Read errors are ignored here; an incomplete body would
                            // presumably surface as a panic in `from_slice(..).unwrap()`.
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // If the test drops the sender without replying, this `?`
                            // propagates the cancellation as a failed HTTP request.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // Dropping the sender in the test simulates a failed
                            // rejection request (the store is then expected to retry).
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1149
/// Text of the 0BSD license example. Writing it as a worktree's `LICENSE`
/// makes that worktree eligible for data collection in the tests below.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1151
1152#[gpui::test]
1153async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1154    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1155    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1156        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1157    });
1158
1159    let edit_preview = cx
1160        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1161        .await;
1162
1163    let completion = EditPrediction {
1164        edits,
1165        edit_preview,
1166        buffer: buffer.clone(),
1167        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1168        id: EditPredictionId("the-id".into()),
1169        inputs: EditPredictionInputs {
1170            events: Default::default(),
1171            included_files: Default::default(),
1172            cursor_point: cloud_llm_client::predict_edits_v3::Point {
1173                line: Line(0),
1174                column: 0,
1175            },
1176            cursor_path: Path::new("").into(),
1177        },
1178        buffer_snapshotted_at: Instant::now(),
1179        response_received_at: Instant::now(),
1180    };
1181
1182    cx.update(|cx| {
1183        assert_eq!(
1184            from_completion_edits(
1185                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1186                &buffer,
1187                cx
1188            ),
1189            vec![(2..5, "REM".into()), (9..11, "".into())]
1190        );
1191
1192        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1193        assert_eq!(
1194            from_completion_edits(
1195                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1196                &buffer,
1197                cx
1198            ),
1199            vec![(2..2, "REM".into()), (6..8, "".into())]
1200        );
1201
1202        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1203        assert_eq!(
1204            from_completion_edits(
1205                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1206                &buffer,
1207                cx
1208            ),
1209            vec![(2..5, "REM".into()), (9..11, "".into())]
1210        );
1211
1212        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1213        assert_eq!(
1214            from_completion_edits(
1215                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1216                &buffer,
1217                cx
1218            ),
1219            vec![(3..3, "EM".into()), (7..9, "".into())]
1220        );
1221
1222        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1223        assert_eq!(
1224            from_completion_edits(
1225                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1226                &buffer,
1227                cx
1228            ),
1229            vec![(4..4, "M".into()), (8..10, "".into())]
1230        );
1231
1232        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1233        assert_eq!(
1234            from_completion_edits(
1235                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1236                &buffer,
1237                cx
1238            ),
1239            vec![(9..11, "".into())]
1240        );
1241
1242        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1243        assert_eq!(
1244            from_completion_edits(
1245                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1246                &buffer,
1247                cx
1248            ),
1249            vec![(4..4, "M".into()), (8..10, "".into())]
1250        );
1251
1252        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1253        assert_eq!(
1254            from_completion_edits(
1255                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1256                &buffer,
1257                cx
1258            ),
1259            vec![(4..4, "M".into())]
1260        );
1261
1262        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1263        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
1264    })
1265}
1266
1267#[gpui::test]
1268async fn test_clean_up_diff(cx: &mut TestAppContext) {
1269    init_test(cx);
1270
1271    assert_eq!(
1272        apply_edit_prediction(
1273            indoc! {"
1274                    fn main() {
1275                        let word_1 = \"lorem\";
1276                        let range = word.len()..word.len();
1277                    }
1278                "},
1279            indoc! {"
1280                    <|editable_region_start|>
1281                    fn main() {
1282                        let word_1 = \"lorem\";
1283                        let range = word_1.len()..word_1.len();
1284                    }
1285
1286                    <|editable_region_end|>
1287                "},
1288            cx,
1289        )
1290        .await,
1291        indoc! {"
1292                fn main() {
1293                    let word_1 = \"lorem\";
1294                    let range = word_1.len()..word_1.len();
1295                }
1296            "},
1297    );
1298
1299    assert_eq!(
1300        apply_edit_prediction(
1301            indoc! {"
1302                    fn main() {
1303                        let story = \"the quick\"
1304                    }
1305                "},
1306            indoc! {"
1307                    <|editable_region_start|>
1308                    fn main() {
1309                        let story = \"the quick brown fox jumps over the lazy dog\";
1310                    }
1311
1312                    <|editable_region_end|>
1313                "},
1314            cx,
1315        )
1316        .await,
1317        indoc! {"
1318                fn main() {
1319                    let story = \"the quick brown fox jumps over the lazy dog\";
1320                }
1321            "},
1322    );
1323}
1324
1325#[gpui::test]
1326async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1327    init_test(cx);
1328
1329    let buffer_content = "lorem\n";
1330    let completion_response = indoc! {"
1331            ```animals.js
1332            <|start_of_file|>
1333            <|editable_region_start|>
1334            lorem
1335            ipsum
1336            <|editable_region_end|>
1337            ```"};
1338
1339    assert_eq!(
1340        apply_edit_prediction(buffer_content, completion_response, cx).await,
1341        "lorem\nipsum"
1342    );
1343}
1344
1345#[gpui::test]
1346async fn test_can_collect_data(cx: &mut TestAppContext) {
1347    init_test(cx);
1348
1349    let fs = project::FakeFs::new(cx.executor());
1350    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1351        .await;
1352
1353    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1354    let buffer = project
1355        .update(cx, |project, cx| {
1356            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1357        })
1358        .await
1359        .unwrap();
1360
1361    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1362    ep_store.update(cx, |ep_store, _cx| {
1363        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1364    });
1365
1366    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1367    assert_eq!(
1368        captured_request.lock().clone().unwrap().can_collect_data,
1369        true
1370    );
1371
1372    ep_store.update(cx, |ep_store, _cx| {
1373        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1374    });
1375
1376    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1377    assert_eq!(
1378        captured_request.lock().clone().unwrap().can_collect_data,
1379        false
1380    );
1381}
1382
1383#[gpui::test]
1384async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1385    init_test(cx);
1386
1387    let fs = project::FakeFs::new(cx.executor());
1388    let project = Project::test(fs.clone(), [], cx).await;
1389
1390    let buffer = cx.new(|_cx| {
1391        Buffer::remote(
1392            language::BufferId::new(1).unwrap(),
1393            ReplicaId::new(1),
1394            language::Capability::ReadWrite,
1395            "fn main() {\n    println!(\"Hello\");\n}",
1396        )
1397    });
1398
1399    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1400    ep_store.update(cx, |ep_store, _cx| {
1401        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1402    });
1403
1404    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1405    assert_eq!(
1406        captured_request.lock().clone().unwrap().can_collect_data,
1407        false
1408    );
1409}
1410
1411#[gpui::test]
1412async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1413    init_test(cx);
1414
1415    let fs = project::FakeFs::new(cx.executor());
1416    fs.insert_tree(
1417        path!("/project"),
1418        json!({
1419            "LICENSE": BSD_0_TXT,
1420            ".env": "SECRET_KEY=secret"
1421        }),
1422    )
1423    .await;
1424
1425    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1426    let buffer = project
1427        .update(cx, |project, cx| {
1428            project.open_local_buffer("/project/.env", cx)
1429        })
1430        .await
1431        .unwrap();
1432
1433    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1434    ep_store.update(cx, |ep_store, _cx| {
1435        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1436    });
1437
1438    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1439    assert_eq!(
1440        captured_request.lock().clone().unwrap().can_collect_data,
1441        false
1442    );
1443}
1444
1445#[gpui::test]
1446async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1447    init_test(cx);
1448
1449    let fs = project::FakeFs::new(cx.executor());
1450    let project = Project::test(fs.clone(), [], cx).await;
1451    let buffer = cx.new(|cx| Buffer::local("", cx));
1452
1453    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1454    ep_store.update(cx, |ep_store, _cx| {
1455        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1456    });
1457
1458    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1459    assert_eq!(
1460        captured_request.lock().clone().unwrap().can_collect_data,
1461        false
1462    );
1463}
1464
1465#[gpui::test]
1466async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1467    init_test(cx);
1468
1469    let fs = project::FakeFs::new(cx.executor());
1470    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1471        .await;
1472
1473    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1474    let buffer = project
1475        .update(cx, |project, cx| {
1476            project.open_local_buffer("/project/main.rs", cx)
1477        })
1478        .await
1479        .unwrap();
1480
1481    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1482    ep_store.update(cx, |ep_store, _cx| {
1483        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1484    });
1485
1486    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1487    assert_eq!(
1488        captured_request.lock().clone().unwrap().can_collect_data,
1489        false
1490    );
1491}
1492
1493#[gpui::test]
1494async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1495    init_test(cx);
1496
1497    let fs = project::FakeFs::new(cx.executor());
1498    fs.insert_tree(
1499        path!("/open_source_worktree"),
1500        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1501    )
1502    .await;
1503    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1504        .await;
1505
1506    let project = Project::test(
1507        fs.clone(),
1508        [
1509            path!("/open_source_worktree").as_ref(),
1510            path!("/closed_source_worktree").as_ref(),
1511        ],
1512        cx,
1513    )
1514    .await;
1515    let buffer = project
1516        .update(cx, |project, cx| {
1517            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1518        })
1519        .await
1520        .unwrap();
1521
1522    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1523    ep_store.update(cx, |ep_store, _cx| {
1524        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1525    });
1526
1527    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1528    assert_eq!(
1529        captured_request.lock().clone().unwrap().can_collect_data,
1530        true
1531    );
1532
1533    let closed_source_file = project
1534        .update(cx, |project, cx| {
1535            let worktree2 = project
1536                .worktree_for_root_name("closed_source_worktree", cx)
1537                .unwrap();
1538            worktree2.update(cx, |worktree2, cx| {
1539                worktree2.load_file(rel_path("main.rs"), cx)
1540            })
1541        })
1542        .await
1543        .unwrap()
1544        .file;
1545
1546    buffer.update(cx, |buffer, cx| {
1547        buffer.file_updated(closed_source_file, cx);
1548    });
1549
1550    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1551    assert_eq!(
1552        captured_request.lock().clone().unwrap().can_collect_data,
1553        false
1554    );
1555}
1556
1557#[gpui::test]
1558async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1559    init_test(cx);
1560
1561    let fs = project::FakeFs::new(cx.executor());
1562    fs.insert_tree(
1563        path!("/worktree1"),
1564        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1565    )
1566    .await;
1567    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1568        .await;
1569
1570    let project = Project::test(
1571        fs.clone(),
1572        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1573        cx,
1574    )
1575    .await;
1576    let buffer = project
1577        .update(cx, |project, cx| {
1578            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1579        })
1580        .await
1581        .unwrap();
1582    let private_buffer = project
1583        .update(cx, |project, cx| {
1584            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1585        })
1586        .await
1587        .unwrap();
1588
1589    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1590    ep_store.update(cx, |ep_store, _cx| {
1591        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1592    });
1593
1594    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1595    assert_eq!(
1596        captured_request.lock().clone().unwrap().can_collect_data,
1597        true
1598    );
1599
1600    // this has a side effect of registering the buffer to watch for edits
1601    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1602    assert_eq!(
1603        captured_request.lock().clone().unwrap().can_collect_data,
1604        false
1605    );
1606
1607    private_buffer.update(cx, |private_buffer, cx| {
1608        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1609    });
1610
1611    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1612    assert_eq!(
1613        captured_request.lock().clone().unwrap().can_collect_data,
1614        false
1615    );
1616
1617    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1618    // included
1619    buffer.update(cx, |buffer, cx| {
1620        buffer.edit(
1621            [(
1622                0..0,
1623                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1624            )],
1625            None,
1626            cx,
1627        );
1628    });
1629
1630    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1631    assert_eq!(
1632        captured_request.lock().clone().unwrap().can_collect_data,
1633        true
1634    );
1635}
1636
1637fn init_test(cx: &mut TestAppContext) {
1638    cx.update(|cx| {
1639        let settings_store = SettingsStore::test(cx);
1640        cx.set_global(settings_store);
1641    });
1642}
1643
1644async fn apply_edit_prediction(
1645    buffer_content: &str,
1646    completion_response: &str,
1647    cx: &mut TestAppContext,
1648) -> String {
1649    let fs = project::FakeFs::new(cx.executor());
1650    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1651    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1652    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1653    *response.lock() = completion_response.to_string();
1654    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1655    buffer.update(cx, |buffer, cx| {
1656        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1657    });
1658    buffer.read_with(cx, |buffer, _| buffer.text())
1659}
1660
1661async fn run_edit_prediction(
1662    buffer: &Entity<Buffer>,
1663    project: &Entity<Project>,
1664    ep_store: &Entity<EditPredictionStore>,
1665    cx: &mut TestAppContext,
1666) -> EditPrediction {
1667    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1668    ep_store.update(cx, |ep_store, cx| {
1669        ep_store.register_buffer(buffer, &project, cx)
1670    });
1671    cx.background_executor.run_until_parked();
1672    let prediction_task = ep_store.update(cx, |ep_store, cx| {
1673        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1674    });
1675    prediction_task.await.unwrap().unwrap().prediction.unwrap()
1676}
1677
1678async fn make_test_ep_store(
1679    project: &Entity<Project>,
1680    cx: &mut TestAppContext,
1681) -> (
1682    Entity<EditPredictionStore>,
1683    Arc<Mutex<Option<PredictEditsBody>>>,
1684    Arc<Mutex<String>>,
1685) {
1686    let default_response = indoc! {"
1687            ```main.rs
1688            <|start_of_file|>
1689            <|editable_region_start|>
1690            hello world
1691            <|editable_region_end|>
1692            ```"
1693    };
1694    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1695    let completion_response: Arc<Mutex<String>> =
1696        Arc::new(Mutex::new(default_response.to_string()));
1697    let http_client = FakeHttpClient::create({
1698        let captured_request = captured_request.clone();
1699        let completion_response = completion_response.clone();
1700        let mut next_request_id = 0;
1701        move |req| {
1702            let captured_request = captured_request.clone();
1703            let completion_response = completion_response.clone();
1704            async move {
1705                match (req.method(), req.uri().path()) {
1706                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1707                        .status(200)
1708                        .body(
1709                            serde_json::to_string(&CreateLlmTokenResponse {
1710                                token: LlmToken("the-llm-token".to_string()),
1711                            })
1712                            .unwrap()
1713                            .into(),
1714                        )
1715                        .unwrap()),
1716                    (&Method::POST, "/predict_edits/v2") => {
1717                        let mut request_body = String::new();
1718                        req.into_body().read_to_string(&mut request_body).await?;
1719                        *captured_request.lock() =
1720                            Some(serde_json::from_str(&request_body).unwrap());
1721                        next_request_id += 1;
1722                        Ok(http_client::Response::builder()
1723                            .status(200)
1724                            .body(
1725                                serde_json::to_string(&PredictEditsResponse {
1726                                    request_id: format!("request-{next_request_id}"),
1727                                    output_excerpt: completion_response.lock().clone(),
1728                                })
1729                                .unwrap()
1730                                .into(),
1731                            )
1732                            .unwrap())
1733                    }
1734                    _ => Ok(http_client::Response::builder()
1735                        .status(404)
1736                        .body("Not Found".into())
1737                        .unwrap()),
1738                }
1739            }
1740        }
1741    });
1742
1743    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1744    cx.update(|cx| {
1745        RefreshLlmTokenListener::register(client.clone(), cx);
1746    });
1747    let _server = FakeServer::for_client(42, &client, cx).await;
1748
1749    let ep_store = cx.new(|cx| {
1750        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
1751        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
1752
1753        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
1754        for worktree in worktrees {
1755            let worktree_id = worktree.read(cx).id();
1756            ep_store
1757                .get_or_init_project(project, cx)
1758                .license_detection_watchers
1759                .entry(worktree_id)
1760                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
1761        }
1762
1763        ep_store
1764    });
1765
1766    (ep_store, captured_request, completion_response)
1767}
1768
1769fn to_completion_edits(
1770    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1771    buffer: &Entity<Buffer>,
1772    cx: &App,
1773) -> Vec<(Range<Anchor>, Arc<str>)> {
1774    let buffer = buffer.read(cx);
1775    iterator
1776        .into_iter()
1777        .map(|(range, text)| {
1778            (
1779                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1780                text,
1781            )
1782        })
1783        .collect()
1784}
1785
1786fn from_completion_edits(
1787    editor_edits: &[(Range<Anchor>, Arc<str>)],
1788    buffer: &Entity<Buffer>,
1789    cx: &App,
1790) -> Vec<(Range<usize>, Arc<str>)> {
1791    let buffer = buffer.read(cx);
1792    editor_edits
1793        .iter()
1794        .map(|(range, text)| {
1795            (
1796                range.start.to_offset(buffer)..range.end.to_offset(buffer),
1797                text.clone(),
1798            )
1799        })
1800        .collect()
1801}
1802
/// Initializes test logging via `zlog::init_test`. `#[ctor::ctor]` runs it once
/// when the test binary loads, before any test starts.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}