edit_prediction_tests.rs

   1use super::*;
   2use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9};
  10use futures::{
  11    AsyncReadExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::{
  15    Entity, TestAppContext,
  16    http_client::{FakeHttpClient, Response},
  17};
  18use indoc::indoc;
  19use language::{Point, ToOffset as _};
  20use lsp::LanguageServerId;
  21use open_ai::Usage;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::{path, rel_path::rel_path};
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
// Verifies how `prediction_at` reports prediction state relative to the
// queried buffer: a prediction editing the active buffer is surfaced as
// `BufferEditPrediction::Local`, while a prediction targeting another file
// (prompted here by a diagnostic in that file) is surfaced as `Jump` from the
// active buffer, and as `Local` once the target file's buffer is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and mark it as the active path so predictions are judged
    // relative to it.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model proposes an edit within the active file (1.txt).
    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // Edit is in the queried buffer itself, so it's reported as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard it so the next prediction starts from a clean slate.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request
    // (no explicit refresh call below).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction lives in a different file, so
    // it's reported as a `Jump` pointing at root/2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 169
 170#[gpui::test]
 171async fn test_simple_request(cx: &mut TestAppContext) {
 172    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 173    let fs = FakeFs::new(cx.executor());
 174    fs.insert_tree(
 175        "/root",
 176        json!({
 177            "foo.md":  "Hello!\nHow\nBye\n"
 178        }),
 179    )
 180    .await;
 181    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 182
 183    let buffer = project
 184        .update(cx, |project, cx| {
 185            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 186            project.open_buffer(path, cx)
 187        })
 188        .await
 189        .unwrap();
 190    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 191    let position = snapshot.anchor_before(language::Point::new(1, 3));
 192
 193    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 194        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 195    });
 196
 197    let (request, respond_tx) = requests.predict.next().await.unwrap();
 198
 199    // TODO Put back when we have a structured request again
 200    // assert_eq!(
 201    //     request.excerpt_path.as_ref(),
 202    //     Path::new(path!("root/foo.md"))
 203    // );
 204    // assert_eq!(
 205    //     request.cursor_point,
 206    //     Point {
 207    //         line: Line(1),
 208    //         column: 3
 209    //     }
 210    // );
 211
 212    respond_tx
 213        .send(model_response(
 214            request,
 215            indoc! { r"
 216                --- a/root/foo.md
 217                +++ b/root/foo.md
 218                @@ ... @@
 219                 Hello!
 220                -How
 221                +How are you?
 222                 Bye
 223            "},
 224        ))
 225        .unwrap();
 226
 227    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 228
 229    assert_eq!(prediction.edits.len(), 1);
 230    assert_eq!(
 231        prediction.edits[0].0.to_point(&snapshot).start,
 232        language::Point::new(1, 3)
 233    );
 234    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 235}
 236
// Verifies that edits made to a registered buffer are recorded as events and
// included in the prediction prompt as a unified diff, and that the response
// is applied back as a single edit at the cursor.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register before editing so the edit below is captured as an event.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" on the empty line 1 (offset 7 is just past "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The recorded edit must appear in the prompt as a udiff of foo.md.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 313
// Verifies that an empty model response yields no prediction and is
// auto-reported to the server as a rejection with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // Respond with an empty diff: no edits to propose.
    let response = model_response(request, "");
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // No prediction is surfaced for an empty response.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
 368
// Verifies that when the user types the predicted text before the response
// arrives, interpolating the prediction against the new buffer contents
// leaves no edits, so it is dropped and rejected as `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is in flight, the buffer is edited to already contain
    // exactly what SIMPLE_DIFF would produce.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(request, SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // Interpolation produced no remaining edits, so nothing is surfaced.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
 428
/// Canned model response reused by several tests below: a udiff against
/// `root/foo.md` that rewrites line 2 from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 438
// Verifies that a newly completed prediction replaces the current one, and
// that the replaced prediction is reported with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First prediction is current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
 520
// Verifies that when a later prediction is judged worse than the current one,
// the current prediction is kept and the newcomer is rejected with reason
// `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First prediction becomes current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction: a strict prefix of SIMPLE_DIFF's edit
    // ("How are" vs "How are you?").
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
 614
// Verifies that when two requests are pending and the newer one completes
// first, the older request is cancelled: its late response is ignored and it
// is reported to the server with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now deliver the stale first response; it must not displace the second.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
 705
// Verifies the pending-request limit: with two requests already in flight,
// starting a third cancels the second (the most recent pending one is kept
// plus the newcomer). The first and third still complete normally — third
// replaces first — and both the cancellation and the replacement are
// reported as rejections.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Deliver the cancelled second request's response; it must be ignored.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both rejections arrive in one batch: the cancelled second request and
    // the replaced first prediction.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
 841
// Verifies the rejection-reporting pipeline: rejections are debounced and
// batched; hitting the batch-size limit flushes immediately (50 per request,
// judging by the assertions below); and a failed send is retried together
// with subsequently queued rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    // Nothing is sent until the debounce window elapses.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // The 20 leftovers (batch-50..batch-69) arrive in the next flush.
    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure: dropping the responder makes the in-flight send fail.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
 953
 954// Skipped until we start including diagnostics in prompt
 955// #[gpui::test]
 956// async fn test_request_diagnostics(cx: &mut TestAppContext) {
 957//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
 958//     let fs = FakeFs::new(cx.executor());
 959//     fs.insert_tree(
 960//         "/root",
 961//         json!({
 962//             "foo.md": "Hello!\nBye"
 963//         }),
 964//     )
 965//     .await;
 966//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 967
 968//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
 969//     let diagnostic = lsp::Diagnostic {
 970//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 971//         severity: Some(lsp::DiagnosticSeverity::ERROR),
 972//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
 973//         ..Default::default()
 974//     };
 975
 976//     project.update(cx, |project, cx| {
 977//         project.lsp_store().update(cx, |lsp_store, cx| {
 978//             // Create some diagnostics
 979//             lsp_store
 980//                 .update_diagnostics(
 981//                     LanguageServerId(0),
 982//                     lsp::PublishDiagnosticsParams {
 983//                         uri: path_to_buffer_uri.clone(),
 984//                         diagnostics: vec![diagnostic],
 985//                         version: None,
 986//                     },
 987//                     None,
 988//                     language::DiagnosticSourceKind::Pushed,
 989//                     &[],
 990//                     cx,
 991//                 )
 992//                 .unwrap();
 993//         });
 994//     });
 995
 996//     let buffer = project
 997//         .update(cx, |project, cx| {
 998//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 999//             project.open_buffer(path, cx)
1000//         })
1001//         .await
1002//         .unwrap();
1003
1004//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1005//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1006
1007//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1008//         ep_store.request_prediction(&project, &buffer, position, cx)
1009//     });
1010
1011//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1012
1013//     assert_eq!(request.diagnostic_groups.len(), 1);
1014//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1015//         .unwrap();
1016//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1017//     assert_eq!(
1018//         value,
1019//         json!({
1020//             "entries": [{
1021//                 "range": {
1022//                     "start": 8,
1023//                     "end": 10
1024//                 },
1025//                 "diagnostic": {
1026//                     "source": null,
1027//                     "code": null,
1028//                     "code_description": null,
1029//                     "severity": 1,
1030//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1031//                     "markdown": null,
1032//                     "group_id": 0,
1033//                     "is_primary": true,
1034//                     "is_disk_based": false,
1035//                     "is_unnecessary": false,
1036//                     "source_kind": "Pushed",
1037//                     "data": null,
1038//                     "underline": true
1039//                 }
1040//             }],
1041//             "primary_ix": 0
1042//         })
1043//     );
1044// }
1045
1046// Generate a model response that would apply the given diff to the active file.
1047fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1048    let prompt = match &request.messages[0] {
1049        open_ai::RequestMessage::User {
1050            content: open_ai::MessageContent::Plain(content),
1051        } => content,
1052        _ => panic!("unexpected request {request:?}"),
1053    };
1054
1055    let open = "<editable_region>\n";
1056    let close = "</editable_region>";
1057    let cursor = "<|user_cursor|>";
1058
1059    let start_ix = open.len() + prompt.find(open).unwrap();
1060    let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1061    let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1062    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1063
1064    open_ai::Response {
1065        id: Uuid::new_v4().to_string(),
1066        object: "response".into(),
1067        created: 0,
1068        model: "model".into(),
1069        choices: vec![open_ai::Choice {
1070            index: 0,
1071            message: open_ai::RequestMessage::Assistant {
1072                content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1073                tool_calls: vec![],
1074            },
1075            finish_reason: None,
1076        }],
1077        usage: Usage {
1078            prompt_tokens: 0,
1079            completion_tokens: 0,
1080            total_tokens: 0,
1081        },
1082    }
1083}
1084
1085fn prompt_from_request(request: &open_ai::Request) -> &str {
1086    assert_eq!(request.messages.len(), 1);
1087    let open_ai::RequestMessage::User {
1088        content: open_ai::MessageContent::Plain(content),
1089        ..
1090    } = &request.messages[0]
1091    else {
1092        panic!(
1093            "Request does not have single user message of type Plain. {:#?}",
1094            request
1095        );
1096    };
1097    content
1098}
1099
/// Receiver halves of the channels through which the fake HTTP client hands
/// intercepted requests to the test. Each item pairs the decoded request body
/// with a oneshot sender the test uses to supply the response.
struct RequestChannels {
    // Requests to `/predict_edits/raw`, awaiting an `open_ai::Response`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Requests to `/predict_edits/reject`, awaiting a unit acknowledgement.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1104
/// Set up an `EditPredictionStore` backed by a fake HTTP client.
///
/// The fake client answers `/client/llm_tokens` with a fixed token and
/// forwards `/predict_edits/raw` and `/predict_edits/reject` request bodies
/// over the returned `RequestChannels`, so the test can inspect each request
/// and decide its response via the paired oneshot sender.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answered immediately, not surfaced to the test.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Hand the decoded request to the test and wait for
                            // it to send back the response to serve.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Same handshake as above; dropping the sender in a
                            // test simulates a failed reject request.
                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1173
// Text of the 0BSD license; tests place it as a worktree's LICENSE file to
// make that worktree count as open source for data-collection checks.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1175
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    // Prediction against "Lorem ipsum dolor" with two edits:
    //   2..5  -> "REM"  (replace "rem")
    //   9..11 -> ""     (delete "um")
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Inputs are irrelevant for interpolation; only edits/snapshot matter here.
    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the edit into a pure insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted text shrinks the remaining edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the final predicted character consumes the first edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of the typed text brings the edit back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the deletion edit by hand consumes it.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1289
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    // Verifies that a model completion is applied to the buffer with spurious
    // artifacts cleaned up: in both cases below the completion carries an
    // extra trailing blank line inside the editable region that must not end
    // up in the resulting buffer text.
    init_test(cx);

    // Case 1: rename of `word` to `word_1` is applied; the blank line before
    // <|editable_region_end|> is dropped.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
    );

    // Case 2: a string-literal completion is applied, again without the
    // completion's trailing blank line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
    );
}
1347
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    // A completion that appends a line at the very end of the buffer applies
    // cleanly; note the expected text has no trailing newline even though the
    // original buffer content did.
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1367
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
    // In a worktree with an open-source LICENSE, `can_collect_data` follows
    // the user's DataCollectionChoice.
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // The 0BSD LICENSE marks this worktree as open source.
    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
        .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = project
        .update(cx, |project, cx| {
            // NOTE(review): src/main.rs is not in the inserted tree —
            // presumably this yields a new, unsaved buffer inside the
            // licensed worktree; confirm against FakeFs/open_local_buffer
            // semantics.
            project.open_local_buffer(path!("/project/src/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Enabled choice + licensed worktree => collection allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Disabled
    });

    // Disabling the choice flips the flag even though the license is present.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1405
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    // Remote buffers must never be eligible for data collection, even when
    // the user has opted in.
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;

    // A buffer backed by a remote replica rather than a local file.
    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1433
1434#[gpui::test]
1435async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1436    init_test(cx);
1437
1438    let fs = project::FakeFs::new(cx.executor());
1439    fs.insert_tree(
1440        path!("/project"),
1441        json!({
1442            "LICENSE": BSD_0_TXT,
1443            ".env": "SECRET_KEY=secret"
1444        }),
1445    )
1446    .await;
1447
1448    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1449    let buffer = project
1450        .update(cx, |project, cx| {
1451            project.open_local_buffer("/project/.env", cx)
1452        })
1453        .await
1454        .unwrap();
1455
1456    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1457    ep_store.update(cx, |ep_store, _cx| {
1458        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1459    });
1460
1461    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1462    assert_eq!(
1463        captured_request.lock().clone().unwrap().can_collect_data,
1464        false
1465    );
1466}
1467
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    // A buffer with no backing file can never be collected, regardless of
    // the user's choice.
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1487
1488#[gpui::test]
1489async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1490    init_test(cx);
1491
1492    let fs = project::FakeFs::new(cx.executor());
1493    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1494        .await;
1495
1496    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1497    let buffer = project
1498        .update(cx, |project, cx| {
1499            project.open_local_buffer("/project/main.rs", cx)
1500        })
1501        .await
1502        .unwrap();
1503
1504    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1505    ep_store.update(cx, |ep_store, _cx| {
1506        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1507    });
1508
1509    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1510    assert_eq!(
1511        captured_request.lock().clone().unwrap().can_collect_data,
1512        false
1513    );
1514}
1515
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    // Collectability is re-evaluated when a buffer's file moves: a buffer
    // starting in a licensed worktree is collectable, but stops being so once
    // its file is swapped for one in an unlicensed worktree.
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // Initially in the licensed worktree => collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Load a file handle from the unlicensed worktree...
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    // ...and attach it to the buffer, simulating a move between worktrees.
    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1579
#[gpui::test]
async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
    // Collection is disabled while the recent edit-event history references an
    // uncollectable buffer, and re-enabled once those events fall out of the
    // event window.
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/worktree1"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
    )
    .await;
    // worktree2 has no LICENSE, so its buffers are uncollectable.
    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
        })
        .await
        .unwrap();
    // NOTE(review): the tree above contains "private.rs" but this opens
    // "file.rs" — presumably an unsaved buffer in the unlicensed worktree,
    // which is uncollectable either way; confirm the mismatch is intentional.
    let private_buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // this has a side effect of registering the buffer to watch for edits
    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    // An edit in the uncollectable buffer enters the event history...
    private_buffer.update(cx, |private_buffer, cx| {
        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
    });

    // ...so a prediction even in the collectable buffer is not collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );

    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
    // included
    buffer.update(cx, |buffer, cx| {
        buffer.edit(
            [(
                0..0,
                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
            )],
            None,
            cx,
        );
    });

    // With the private edit evicted from the event window, collection resumes.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );
}
1659
1660fn init_test(cx: &mut TestAppContext) {
1661    cx.update(|cx| {
1662        let settings_store = SettingsStore::test(cx);
1663        cx.set_global(settings_store);
1664    });
1665}
1666
/// Run one edit prediction for `buffer_content` with the fake server returning
/// `completion_response`, apply the resulting edits, and return the buffer's
/// final text.
async fn apply_edit_prediction(
    buffer_content: &str,
    completion_response: &str,
    cx: &mut TestAppContext,
) -> String {
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
    // Override the store's canned completion with the caller-provided one.
    *response.lock() = completion_response.to_string();
    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    buffer.update(cx, |buffer, cx| {
        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
    });
    buffer.read_with(cx, |buffer, _| buffer.text())
}
1683
/// Register `buffer` with the store and request a prediction at the start of
/// line 1, unwrapping the returned `EditPrediction`.
///
/// Panics if the request fails or no prediction is produced.
async fn run_edit_prediction(
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    ep_store: &Entity<EditPredictionStore>,
    cx: &mut TestAppContext,
) -> EditPrediction {
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(buffer, &project, cx)
    });
    // Let registration side effects settle before requesting.
    cx.background_executor.run_until_parked();
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
    });
    prediction_task.await.unwrap().unwrap().prediction.unwrap()
}
1700
/// Build an `EditPredictionStore` (Zeta1 model) wired to a fake HTTP client.
///
/// Returns the store plus two shared handles: the most recent captured
/// `/predict_edits/v2` request body, and the completion text the fake server
/// will return (pre-filled with a simple default; overwrite to customize).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh: always succeeds with a fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction request: capture the body for the test and
                    // reply with the configured completion text.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install license-detection watchers so data-collection tests see
        // each worktree's LICENSE state.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
1791
1792fn to_completion_edits(
1793    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1794    buffer: &Entity<Buffer>,
1795    cx: &App,
1796) -> Vec<(Range<Anchor>, Arc<str>)> {
1797    let buffer = buffer.read(cx);
1798    iterator
1799        .into_iter()
1800        .map(|(range, text)| {
1801            (
1802                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1803                text,
1804            )
1805        })
1806        .collect()
1807}
1808
1809fn from_completion_edits(
1810    editor_edits: &[(Range<Anchor>, Arc<str>)],
1811    buffer: &Entity<Buffer>,
1812    cx: &App,
1813) -> Vec<(Range<usize>, Arc<str>)> {
1814    let buffer = buffer.read(cx);
1815    editor_edits
1816        .iter()
1817        .map(|(range, text)| {
1818            (
1819                range.start.to_offset(buffer)..range.end.to_offset(buffer),
1820                text.clone(),
1821            )
1822        })
1823        .collect()
1824}
1825
// Runs once at binary startup (before any test) to initialize test logging.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}