edit_prediction_tests.rs

   1use super::*;
   2use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9};
  10use futures::{
  11    AsyncReadExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::{
  15    Entity, TestAppContext,
  16    http_client::{FakeHttpClient, Response},
  17};
  18use indoc::indoc;
  19use language::{Point, ToOffset as _};
  20use lsp::LanguageServerId;
  21use open_ai::Usage;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::{path, rel_path::rel_path};
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    // End-to-end check of `current_prediction_for_buffer`: a prediction for
    // the buffer being queried surfaces as `Local`, while a prediction for a
    // different file (triggered here by a pushed diagnostic) surfaces as
    // `Jump` — until that other file's buffer is queried directly.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
    });

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Mark 1.txt as the active file so predictions target it.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        // The edit targets the buffer we asked about, so it is `Local`.
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Pushing a diagnostic for 2.txt should trigger a new prediction request
    // even though 1.txt remains the active file.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from buffer1's perspective, the 2.txt prediction is a `Jump`.
    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from buffer2 itself, the same prediction is `Local`.
    ep_store.read_with(cx, |ep_store, cx| {
        let prediction = ep_store
            .current_prediction_for_buffer(&buffer2, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 168
 169#[gpui::test]
 170async fn test_simple_request(cx: &mut TestAppContext) {
 171    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 172    let fs = FakeFs::new(cx.executor());
 173    fs.insert_tree(
 174        "/root",
 175        json!({
 176            "foo.md":  "Hello!\nHow\nBye\n"
 177        }),
 178    )
 179    .await;
 180    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 181
 182    let buffer = project
 183        .update(cx, |project, cx| {
 184            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 185            project.open_buffer(path, cx)
 186        })
 187        .await
 188        .unwrap();
 189    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 190    let position = snapshot.anchor_before(language::Point::new(1, 3));
 191
 192    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 193        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 194    });
 195
 196    let (request, respond_tx) = requests.predict.next().await.unwrap();
 197
 198    // TODO Put back when we have a structured request again
 199    // assert_eq!(
 200    //     request.excerpt_path.as_ref(),
 201    //     Path::new(path!("root/foo.md"))
 202    // );
 203    // assert_eq!(
 204    //     request.cursor_point,
 205    //     Point {
 206    //         line: Line(1),
 207    //         column: 3
 208    //     }
 209    // );
 210
 211    respond_tx
 212        .send(model_response(
 213            request,
 214            indoc! { r"
 215                --- a/root/foo.md
 216                +++ b/root/foo.md
 217                @@ ... @@
 218                 Hello!
 219                -How
 220                +How are you?
 221                 Bye
 222            "},
 223        ))
 224        .unwrap();
 225
 226    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 227
 228    assert_eq!(prediction.edits.len(), 1);
 229    assert_eq!(
 230        prediction.edits[0].0.to_point(&snapshot).start,
 231        language::Point::new(1, 3)
 232    );
 233    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 234}
 235
 236#[gpui::test]
 237async fn test_request_events(cx: &mut TestAppContext) {
 238    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 239    let fs = FakeFs::new(cx.executor());
 240    fs.insert_tree(
 241        "/root",
 242        json!({
 243            "foo.md": "Hello!\n\nBye\n"
 244        }),
 245    )
 246    .await;
 247    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 248
 249    let buffer = project
 250        .update(cx, |project, cx| {
 251            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 252            project.open_buffer(path, cx)
 253        })
 254        .await
 255        .unwrap();
 256
 257    ep_store.update(cx, |ep_store, cx| {
 258        ep_store.register_buffer(&buffer, &project, cx);
 259    });
 260
 261    buffer.update(cx, |buffer, cx| {
 262        buffer.edit(vec![(7..7, "How")], None, cx);
 263    });
 264
 265    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 266    let position = snapshot.anchor_before(language::Point::new(1, 3));
 267
 268    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 269        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 270    });
 271
 272    let (request, respond_tx) = requests.predict.next().await.unwrap();
 273
 274    let prompt = prompt_from_request(&request);
 275    assert!(
 276        prompt.contains(indoc! {"
 277        --- a/root/foo.md
 278        +++ b/root/foo.md
 279        @@ -1,3 +1,3 @@
 280         Hello!
 281        -
 282        +How
 283         Bye
 284    "}),
 285        "{prompt}"
 286    );
 287
 288    respond_tx
 289        .send(model_response(
 290            request,
 291            indoc! {r#"
 292                --- a/root/foo.md
 293                +++ b/root/foo.md
 294                @@ ... @@
 295                 Hello!
 296                -How
 297                +How are you?
 298                 Bye
 299        "#},
 300        ))
 301        .unwrap();
 302
 303    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 304
 305    assert_eq!(prediction.edits.len(), 1);
 306    assert_eq!(
 307        prediction.edits[0].0.to_point(&snapshot).start,
 308        language::Point::new(1, 3)
 309    );
 310    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 311}
 312
 313#[gpui::test]
 314async fn test_empty_prediction(cx: &mut TestAppContext) {
 315    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 316    let fs = FakeFs::new(cx.executor());
 317    fs.insert_tree(
 318        "/root",
 319        json!({
 320            "foo.md":  "Hello!\nHow\nBye\n"
 321        }),
 322    )
 323    .await;
 324    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 325
 326    let buffer = project
 327        .update(cx, |project, cx| {
 328            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 329            project.open_buffer(path, cx)
 330        })
 331        .await
 332        .unwrap();
 333    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 334    let position = snapshot.anchor_before(language::Point::new(1, 3));
 335
 336    ep_store.update(cx, |ep_store, cx| {
 337        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 338    });
 339
 340    let (request, respond_tx) = requests.predict.next().await.unwrap();
 341    let response = model_response(request, "");
 342    let id = response.id.clone();
 343    respond_tx.send(response).unwrap();
 344
 345    cx.run_until_parked();
 346
 347    ep_store.read_with(cx, |ep_store, cx| {
 348        assert!(
 349            ep_store
 350                .current_prediction_for_buffer(&buffer, &project, cx)
 351                .is_none()
 352        );
 353    });
 354
 355    // prediction is reported as rejected
 356    let (reject_request, _) = requests.reject.next().await.unwrap();
 357
 358    assert_eq!(
 359        &reject_request.rejections,
 360        &[EditPredictionRejection {
 361            request_id: id,
 362            reason: EditPredictionRejectReason::Empty,
 363            was_shown: false
 364        }]
 365    );
 366}
 367
 368#[gpui::test]
 369async fn test_interpolated_empty(cx: &mut TestAppContext) {
 370    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 371    let fs = FakeFs::new(cx.executor());
 372    fs.insert_tree(
 373        "/root",
 374        json!({
 375            "foo.md":  "Hello!\nHow\nBye\n"
 376        }),
 377    )
 378    .await;
 379    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 380
 381    let buffer = project
 382        .update(cx, |project, cx| {
 383            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 384            project.open_buffer(path, cx)
 385        })
 386        .await
 387        .unwrap();
 388    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 389    let position = snapshot.anchor_before(language::Point::new(1, 3));
 390
 391    ep_store.update(cx, |ep_store, cx| {
 392        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 393    });
 394
 395    let (request, respond_tx) = requests.predict.next().await.unwrap();
 396
 397    buffer.update(cx, |buffer, cx| {
 398        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 399    });
 400
 401    let response = model_response(request, SIMPLE_DIFF);
 402    let id = response.id.clone();
 403    respond_tx.send(response).unwrap();
 404
 405    cx.run_until_parked();
 406
 407    ep_store.read_with(cx, |ep_store, cx| {
 408        assert!(
 409            ep_store
 410                .current_prediction_for_buffer(&buffer, &project, cx)
 411                .is_none()
 412        );
 413    });
 414
 415    // prediction is reported as rejected
 416    let (reject_request, _) = requests.reject.next().await.unwrap();
 417
 418    assert_eq!(
 419        &reject_request.rejections,
 420        &[EditPredictionRejection {
 421            request_id: id,
 422            reason: EditPredictionRejectReason::InterpolatedEmpty,
 423            was_shown: false
 424        }]
 425    );
 426}
 427
/// Canonical model diff shared by several tests below: rewrites line 2 of
/// `root/foo.md` from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 437
 438#[gpui::test]
 439async fn test_replace_current(cx: &mut TestAppContext) {
 440    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 441    let fs = FakeFs::new(cx.executor());
 442    fs.insert_tree(
 443        "/root",
 444        json!({
 445            "foo.md":  "Hello!\nHow\nBye\n"
 446        }),
 447    )
 448    .await;
 449    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 450
 451    let buffer = project
 452        .update(cx, |project, cx| {
 453            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 454            project.open_buffer(path, cx)
 455        })
 456        .await
 457        .unwrap();
 458    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 459    let position = snapshot.anchor_before(language::Point::new(1, 3));
 460
 461    ep_store.update(cx, |ep_store, cx| {
 462        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 463    });
 464
 465    let (request, respond_tx) = requests.predict.next().await.unwrap();
 466    let first_response = model_response(request, SIMPLE_DIFF);
 467    let first_id = first_response.id.clone();
 468    respond_tx.send(first_response).unwrap();
 469
 470    cx.run_until_parked();
 471
 472    ep_store.read_with(cx, |ep_store, cx| {
 473        assert_eq!(
 474            ep_store
 475                .current_prediction_for_buffer(&buffer, &project, cx)
 476                .unwrap()
 477                .id
 478                .0,
 479            first_id
 480        );
 481    });
 482
 483    // a second request is triggered
 484    ep_store.update(cx, |ep_store, cx| {
 485        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 486    });
 487
 488    let (request, respond_tx) = requests.predict.next().await.unwrap();
 489    let second_response = model_response(request, SIMPLE_DIFF);
 490    let second_id = second_response.id.clone();
 491    respond_tx.send(second_response).unwrap();
 492
 493    cx.run_until_parked();
 494
 495    ep_store.read_with(cx, |ep_store, cx| {
 496        // second replaces first
 497        assert_eq!(
 498            ep_store
 499                .current_prediction_for_buffer(&buffer, &project, cx)
 500                .unwrap()
 501                .id
 502                .0,
 503            second_id
 504        );
 505    });
 506
 507    // first is reported as replaced
 508    let (reject_request, _) = requests.reject.next().await.unwrap();
 509
 510    assert_eq!(
 511        &reject_request.rejections,
 512        &[EditPredictionRejection {
 513            request_id: first_id,
 514            reason: EditPredictionRejectReason::Replaced,
 515            was_shown: false
 516        }]
 517    );
 518}
 519
 520#[gpui::test]
 521async fn test_current_preferred(cx: &mut TestAppContext) {
 522    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 523    let fs = FakeFs::new(cx.executor());
 524    fs.insert_tree(
 525        "/root",
 526        json!({
 527            "foo.md":  "Hello!\nHow\nBye\n"
 528        }),
 529    )
 530    .await;
 531    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 532
 533    let buffer = project
 534        .update(cx, |project, cx| {
 535            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 536            project.open_buffer(path, cx)
 537        })
 538        .await
 539        .unwrap();
 540    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 541    let position = snapshot.anchor_before(language::Point::new(1, 3));
 542
 543    ep_store.update(cx, |ep_store, cx| {
 544        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 545    });
 546
 547    let (request, respond_tx) = requests.predict.next().await.unwrap();
 548    let first_response = model_response(request, SIMPLE_DIFF);
 549    let first_id = first_response.id.clone();
 550    respond_tx.send(first_response).unwrap();
 551
 552    cx.run_until_parked();
 553
 554    ep_store.read_with(cx, |ep_store, cx| {
 555        assert_eq!(
 556            ep_store
 557                .current_prediction_for_buffer(&buffer, &project, cx)
 558                .unwrap()
 559                .id
 560                .0,
 561            first_id
 562        );
 563    });
 564
 565    // a second request is triggered
 566    ep_store.update(cx, |ep_store, cx| {
 567        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 568    });
 569
 570    let (request, respond_tx) = requests.predict.next().await.unwrap();
 571    // worse than current prediction
 572    let second_response = model_response(
 573        request,
 574        indoc! { r"
 575            --- a/root/foo.md
 576            +++ b/root/foo.md
 577            @@ ... @@
 578             Hello!
 579            -How
 580            +How are
 581             Bye
 582        "},
 583    );
 584    let second_id = second_response.id.clone();
 585    respond_tx.send(second_response).unwrap();
 586
 587    cx.run_until_parked();
 588
 589    ep_store.read_with(cx, |ep_store, cx| {
 590        // first is preferred over second
 591        assert_eq!(
 592            ep_store
 593                .current_prediction_for_buffer(&buffer, &project, cx)
 594                .unwrap()
 595                .id
 596                .0,
 597            first_id
 598        );
 599    });
 600
 601    // second is reported as rejected
 602    let (reject_request, _) = requests.reject.next().await.unwrap();
 603
 604    assert_eq!(
 605        &reject_request.rejections,
 606        &[EditPredictionRejection {
 607            request_id: second_id,
 608            reason: EditPredictionRejectReason::CurrentPreferred,
 609            was_shown: false
 610        }]
 611    );
 612}
 613
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    // With two requests in flight, the later one resolving first cancels the
    // earlier one: the earlier request's late response must not overwrite the
    // current prediction, and it is reported as `Canceled`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Deliver the first request's response late, after it was cancelled.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
 704
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    // With two requests already pending, starting a third cancels the second
    // (the store keeps at most two live requests). The first still resolves
    // and becomes current, the cancelled second's response is ignored, and
    // the third eventually replaces the first. Both the cancellation and the
    // replacement are reported in a single batched reject request.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Deliver the cancelled (second) request's response; it must be ignored.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.read_with(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected (cancelled), first as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
 840
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    // Covers the rejection-reporting pipeline: rejections are debounced and
    // batched into one request, a full batch is flushed immediately without
    // waiting for the debounce, and a failed send is retried together with
    // later rejections.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    // Two rejections queued back-to-back should be sent as one batch.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure by dropping the responder without replying.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
 952
 953// Skipped until we start including diagnostics in prompt
 954// #[gpui::test]
 955// async fn test_request_diagnostics(cx: &mut TestAppContext) {
 956//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
 957//     let fs = FakeFs::new(cx.executor());
 958//     fs.insert_tree(
 959//         "/root",
 960//         json!({
 961//             "foo.md": "Hello!\nBye"
 962//         }),
 963//     )
 964//     .await;
 965//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 966
 967//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
 968//     let diagnostic = lsp::Diagnostic {
 969//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 970//         severity: Some(lsp::DiagnosticSeverity::ERROR),
 971//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
 972//         ..Default::default()
 973//     };
 974
 975//     project.update(cx, |project, cx| {
 976//         project.lsp_store().update(cx, |lsp_store, cx| {
 977//             // Create some diagnostics
 978//             lsp_store
 979//                 .update_diagnostics(
 980//                     LanguageServerId(0),
 981//                     lsp::PublishDiagnosticsParams {
 982//                         uri: path_to_buffer_uri.clone(),
 983//                         diagnostics: vec![diagnostic],
 984//                         version: None,
 985//                     },
 986//                     None,
 987//                     language::DiagnosticSourceKind::Pushed,
 988//                     &[],
 989//                     cx,
 990//                 )
 991//                 .unwrap();
 992//         });
 993//     });
 994
 995//     let buffer = project
 996//         .update(cx, |project, cx| {
 997//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 998//             project.open_buffer(path, cx)
 999//         })
1000//         .await
1001//         .unwrap();
1002
1003//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1004//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1005
1006//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1007//         ep_store.request_prediction(&project, &buffer, position, cx)
1008//     });
1009
1010//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1011
1012//     assert_eq!(request.diagnostic_groups.len(), 1);
1013//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1014//         .unwrap();
1015//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1016//     assert_eq!(
1017//         value,
1018//         json!({
1019//             "entries": [{
1020//                 "range": {
1021//                     "start": 8,
1022//                     "end": 10
1023//                 },
1024//                 "diagnostic": {
1025//                     "source": null,
1026//                     "code": null,
1027//                     "code_description": null,
1028//                     "severity": 1,
1029//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1030//                     "markdown": null,
1031//                     "group_id": 0,
1032//                     "is_primary": true,
1033//                     "is_disk_based": false,
1034//                     "is_unnecessary": false,
1035//                     "source_kind": "Pushed",
1036//                     "data": null,
1037//                     "underline": true
1038//                 }
1039//             }],
1040//             "primary_ix": 0
1041//         })
1042//     );
1043// }
1044
1045// Generate a model response that would apply the given diff to the active file.
1046fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1047    let prompt = match &request.messages[0] {
1048        open_ai::RequestMessage::User {
1049            content: open_ai::MessageContent::Plain(content),
1050        } => content,
1051        _ => panic!("unexpected request {request:?}"),
1052    };
1053
1054    let open = "<editable_region>\n";
1055    let close = "</editable_region>";
1056    let cursor = "<|user_cursor|>";
1057
1058    let start_ix = open.len() + prompt.find(open).unwrap();
1059    let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1060    let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1061    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1062
1063    open_ai::Response {
1064        id: Uuid::new_v4().to_string(),
1065        object: "response".into(),
1066        created: 0,
1067        model: "model".into(),
1068        choices: vec![open_ai::Choice {
1069            index: 0,
1070            message: open_ai::RequestMessage::Assistant {
1071                content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1072                tool_calls: vec![],
1073            },
1074            finish_reason: None,
1075        }],
1076        usage: Usage {
1077            prompt_tokens: 0,
1078            completion_tokens: 0,
1079            total_tokens: 0,
1080        },
1081    }
1082}
1083
1084fn prompt_from_request(request: &open_ai::Request) -> &str {
1085    assert_eq!(request.messages.len(), 1);
1086    let open_ai::RequestMessage::User {
1087        content: open_ai::MessageContent::Plain(content),
1088        ..
1089    } = &request.messages[0]
1090    else {
1091        panic!(
1092            "Request does not have single user message of type Plain. {:#?}",
1093            request
1094        );
1095    };
1096    content
1097}
1098
/// Receiving ends of the channels through which the fake HTTP client surfaces
/// incoming requests to the test, each paired with a sender for the response.
struct RequestChannels {
    // Prediction requests hitting `/predict_edits/raw`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Rejection batches hitting `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1103
/// Sets up an `EditPredictionStore` backed by a fake HTTP client.
///
/// Requests to `/predict_edits/raw` and `/predict_edits/reject` are forwarded
/// to the returned `RequestChannels`, so the test can inspect each request and
/// decide when (and whether) its response is delivered.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answer immediately with a fixed token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Forward to the test and wait for its reply; a
                            // dropped sender surfaces as a request error here.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Same forwarding scheme as above for rejections.
                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1172
1173const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1174
/// Verifies that a prediction's edits are re-anchored ("interpolated") onto a
/// buffer as the user types, shrinking to the not-yet-typed remainder, and
/// become `None` once the buffer diverges from the predicted text.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits: replace offsets 2..5 with "REM" and delete 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Minimal prediction; `inputs` contents are irrelevant to interpolation.
    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the to-be-replaced range turns it into a pure insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the replacement shrinks the pending edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Replacement fully typed: only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting a typed character brings the insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion manually removes that edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // Diverging from the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1288
/// Verifies that applying a model completion cleans up the resulting diff:
/// spurious changes in the model output (like the extra trailing blank line in
/// both responses below) are not applied to the buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Model renames `word` to `word_1`; the added blank line is discarded.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
    );

    // Model extends a string literal; again the blank line is discarded.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
    );
}
1346
1347#[gpui::test]
1348async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1349    init_test(cx);
1350
1351    let buffer_content = "lorem\n";
1352    let completion_response = indoc! {"
1353            ```animals.js
1354            <|start_of_file|>
1355            <|editable_region_start|>
1356            lorem
1357            ipsum
1358            <|editable_region_end|>
1359            ```"};
1360
1361    assert_eq!(
1362        apply_edit_prediction(buffer_content, completion_response, cx).await,
1363        "lorem\nipsum"
1364    );
1365}
1366
/// `can_collect_data` in the request follows the user's data-collection
/// choice when the worktree is open source (has a permissive LICENSE).
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // Permissive LICENSE marks the worktree as eligible for collection.
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Disabling the choice flips the flag off.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1404
/// Remote buffers never permit data collection, even with collection enabled.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    // A buffer backed by a remote replica rather than a local file.
    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1432
1433#[gpui::test]
1434async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1435    init_test(cx);
1436
1437    let fs = project::FakeFs::new(cx.executor());
1438    fs.insert_tree(
1439        path!("/project"),
1440        json!({
1441            "LICENSE": BSD_0_TXT,
1442            ".env": "SECRET_KEY=secret"
1443        }),
1444    )
1445    .await;
1446
1447    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1448    let buffer = project
1449        .update(cx, |project, cx| {
1450            project.open_local_buffer("/project/.env", cx)
1451        })
1452        .await
1453        .unwrap();
1454
1455    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1456    ep_store.update(cx, |ep_store, _cx| {
1457        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1458    });
1459
1460    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1461    assert_eq!(
1462        captured_request.lock().clone().unwrap().can_collect_data,
1463        false
1464    );
1465}
1466
/// Untitled (file-less) buffers never permit data collection.
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    // Local buffer with no backing file.
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1486
1487#[gpui::test]
1488async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1489    init_test(cx);
1490
1491    let fs = project::FakeFs::new(cx.executor());
1492    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1493        .await;
1494
1495    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1496    let buffer = project
1497        .update(cx, |project, cx| {
1498            project.open_local_buffer("/project/main.rs", cx)
1499        })
1500        .await
1501        .unwrap();
1502
1503    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1504    ep_store.update(cx, |ep_store, _cx| {
1505        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1506    });
1507
1508    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1509    assert_eq!(
1510        captured_request.lock().clone().unwrap().can_collect_data,
1511        false
1512    );
1513}
1514
/// Moving a buffer's backing file from an open-source worktree to a
/// closed-source one turns data collection off for subsequent requests.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // One worktree with a permissive LICENSE, one without.
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Simulate a move by pointing the buffer at a file in the closed-source
    // worktree.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1578
/// Data collection is disabled while the edit history contains events from
/// buffers that are themselves uncollectable (here: a buffer in a worktree
/// with no LICENSE), and recovers once those events fall out of the window.
#[gpui::test]
async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/worktree1"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
        })
        .await
        .unwrap();
    // NOTE(review): the tree above creates `private.rs`, but this opens
    // `file.rs` — presumably a nonexistent-path buffer in worktree2 is
    // equally uncollectable; confirm whether `private.rs` was intended.
    let private_buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // this has a side effect of registering the buffer to watch for edits
    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    // An edit in the uncollectable buffer taints the shared event history.
    private_buffer.update(cx, |private_buffer, cx| {
        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
    // included
    buffer.update(cx, |buffer, cx| {
        buffer.edit(
            [(
                0..0,
                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
            )],
            None,
            cx,
        );
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );
}
1658
1659fn init_test(cx: &mut TestAppContext) {
1660    cx.update(|cx| {
1661        let settings_store = SettingsStore::test(cx);
1662        cx.set_global(settings_store);
1663    });
1664}
1665
/// Runs one edit prediction against a buffer containing `buffer_content`,
/// with the fake server returning `completion_response`, applies the
/// resulting edits, and returns the buffer's final text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
    // Override the default canned completion for this call.
    *response.lock() = completion_response.to_string();
    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
1682
/// Registers `buffer` with the store and requests a prediction with the
/// cursor at the start of the second line, returning the resulting
/// `EditPrediction` (panics if the request fails or yields none).
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
1699
/// Builds an `EditPredictionStore` (Zeta1 model) wired to a fake HTTP server.
///
/// Returns the store, a slot capturing the most recent `/predict_edits/v2`
/// request body, and a slot holding the completion text the fake server will
/// return (pre-filled with a default response).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        // Used to give each prediction response a unique request id.
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh: respond with a fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction: capture the request, return the canned text.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install a license-detection watcher per worktree so the
        // data-collection tests observe LICENSE files.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
1790
1791fn to_completion_edits(
1792    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1793    buffer: &Entity<Buffer>,
1794    cx: &App,
1795) -> Vec<(Range<Anchor>, Arc<str>)> {
1796    let buffer = buffer.read(cx);
1797    iterator
1798        .into_iter()
1799        .map(|(range, text)| {
1800            (
1801                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1802                text,
1803            )
1804        })
1805        .collect()
1806}
1807
1808fn from_completion_edits(
1809    editor_edits: &[(Range<Anchor>, Arc<str>)],
1810    buffer: &Entity<Buffer>,
1811    cx: &App,
1812) -> Vec<(Range<usize>, Arc<str>)> {
1813    let buffer = buffer.read(cx);
1814    editor_edits
1815        .iter()
1816        .map(|(range, text)| {
1817            (
1818                range.start.to_offset(buffer)..range.end.to_offset(buffer),
1819                text.clone(),
1820            )
1821        })
1822        .collect()
1823}
1824
// Runs at binary load time (before any test): installs the test logger.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}