edit_prediction_tests.rs

   1use super::*;
   2use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9};
  10use futures::{
  11    AsyncReadExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::{
  15    Entity, TestAppContext,
  16    http_client::{FakeHttpClient, Response},
  17};
  18use indoc::indoc;
  19use language::{Point, ToOffset as _};
  20use lsp::LanguageServerId;
  21use open_ai::Usage;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::{path, rel_path::rel_path};
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
  34#[gpui::test]
  35async fn test_current_state(cx: &mut TestAppContext) {
  36    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  37    let fs = FakeFs::new(cx.executor());
  38    fs.insert_tree(
  39        "/root",
  40        json!({
  41            "1.txt": "Hello!\nHow\nBye\n",
  42            "2.txt": "Hola!\nComo\nAdios\n"
  43        }),
  44    )
  45    .await;
  46    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  47
  48    let buffer1 = project
  49        .update(cx, |project, cx| {
  50            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  51            project.set_active_path(Some(path.clone()), cx);
  52            project.open_buffer(path, cx)
  53        })
  54        .await
  55        .unwrap();
  56    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  57    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  58
  59    ep_store.update(cx, |ep_store, cx| {
  60        ep_store.register_project(&project, cx);
  61        ep_store.register_buffer(&buffer1, &project, cx);
  62    });
  63
  64    // Prediction for current file
  65
  66    ep_store.update(cx, |ep_store, cx| {
  67        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  68    });
  69    let (request, respond_tx) = requests.predict.next().await.unwrap();
  70
  71    respond_tx
  72        .send(model_response(
  73            request,
  74            indoc! {r"
  75                --- a/root/1.txt
  76                +++ b/root/1.txt
  77                @@ ... @@
  78                 Hello!
  79                -How
  80                +How are you?
  81                 Bye
  82            "},
  83        ))
  84        .unwrap();
  85
  86    cx.run_until_parked();
  87
  88    ep_store.update(cx, |ep_store, cx| {
  89        let prediction = ep_store
  90            .prediction_at(&buffer1, None, &project, cx)
  91            .unwrap();
  92        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  93    });
  94
  95    ep_store.update(cx, |ep_store, _cx| {
  96        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
  97    });
  98
  99    // Prediction for diagnostic in another file
 100
 101    let diagnostic = lsp::Diagnostic {
 102        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 103        severity: Some(lsp::DiagnosticSeverity::ERROR),
 104        message: "Sentence is incomplete".to_string(),
 105        ..Default::default()
 106    };
 107
 108    project.update(cx, |project, cx| {
 109        project.lsp_store().update(cx, |lsp_store, cx| {
 110            lsp_store
 111                .update_diagnostics(
 112                    LanguageServerId(0),
 113                    lsp::PublishDiagnosticsParams {
 114                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 115                        diagnostics: vec![diagnostic],
 116                        version: None,
 117                    },
 118                    None,
 119                    language::DiagnosticSourceKind::Pushed,
 120                    &[],
 121                    cx,
 122                )
 123                .unwrap();
 124        });
 125    });
 126
 127    let (request, respond_tx) = requests.predict.next().await.unwrap();
 128    respond_tx
 129        .send(model_response(
 130            request,
 131            indoc! {r#"
 132                --- a/root/2.txt
 133                +++ b/root/2.txt
 134                @@ ... @@
 135                 Hola!
 136                -Como
 137                +Como estas?
 138                 Adios
 139            "#},
 140        ))
 141        .unwrap();
 142    cx.run_until_parked();
 143
 144    ep_store.update(cx, |ep_store, cx| {
 145        let prediction = ep_store
 146            .prediction_at(&buffer1, None, &project, cx)
 147            .unwrap();
 148        assert_matches!(
 149            prediction,
 150            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 151        );
 152    });
 153
 154    let buffer2 = project
 155        .update(cx, |project, cx| {
 156            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 157            project.open_buffer(path, cx)
 158        })
 159        .await
 160        .unwrap();
 161
 162    ep_store.update(cx, |ep_store, cx| {
 163        let prediction = ep_store
 164            .prediction_at(&buffer2, None, &project, cx)
 165            .unwrap();
 166        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 167    });
 168}
 169
 170#[gpui::test]
 171async fn test_simple_request(cx: &mut TestAppContext) {
 172    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 173    let fs = FakeFs::new(cx.executor());
 174    fs.insert_tree(
 175        "/root",
 176        json!({
 177            "foo.md":  "Hello!\nHow\nBye\n"
 178        }),
 179    )
 180    .await;
 181    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 182
 183    let buffer = project
 184        .update(cx, |project, cx| {
 185            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 186            project.open_buffer(path, cx)
 187        })
 188        .await
 189        .unwrap();
 190    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 191    let position = snapshot.anchor_before(language::Point::new(1, 3));
 192
 193    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 194        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 195    });
 196
 197    let (request, respond_tx) = requests.predict.next().await.unwrap();
 198
 199    // TODO Put back when we have a structured request again
 200    // assert_eq!(
 201    //     request.excerpt_path.as_ref(),
 202    //     Path::new(path!("root/foo.md"))
 203    // );
 204    // assert_eq!(
 205    //     request.cursor_point,
 206    //     Point {
 207    //         line: Line(1),
 208    //         column: 3
 209    //     }
 210    // );
 211
 212    respond_tx
 213        .send(model_response(
 214            request,
 215            indoc! { r"
 216                --- a/root/foo.md
 217                +++ b/root/foo.md
 218                @@ ... @@
 219                 Hello!
 220                -How
 221                +How are you?
 222                 Bye
 223            "},
 224        ))
 225        .unwrap();
 226
 227    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 228
 229    assert_eq!(prediction.edits.len(), 1);
 230    assert_eq!(
 231        prediction.edits[0].0.to_point(&snapshot).start,
 232        language::Point::new(1, 3)
 233    );
 234    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 235}
 236
 237#[gpui::test]
 238async fn test_request_events(cx: &mut TestAppContext) {
 239    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 240    let fs = FakeFs::new(cx.executor());
 241    fs.insert_tree(
 242        "/root",
 243        json!({
 244            "foo.md": "Hello!\n\nBye\n"
 245        }),
 246    )
 247    .await;
 248    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 249
 250    let buffer = project
 251        .update(cx, |project, cx| {
 252            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 253            project.open_buffer(path, cx)
 254        })
 255        .await
 256        .unwrap();
 257
 258    ep_store.update(cx, |ep_store, cx| {
 259        ep_store.register_buffer(&buffer, &project, cx);
 260    });
 261
 262    buffer.update(cx, |buffer, cx| {
 263        buffer.edit(vec![(7..7, "How")], None, cx);
 264    });
 265
 266    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 267    let position = snapshot.anchor_before(language::Point::new(1, 3));
 268
 269    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 270        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 271    });
 272
 273    let (request, respond_tx) = requests.predict.next().await.unwrap();
 274
 275    let prompt = prompt_from_request(&request);
 276    assert!(
 277        prompt.contains(indoc! {"
 278        --- a/root/foo.md
 279        +++ b/root/foo.md
 280        @@ -1,3 +1,3 @@
 281         Hello!
 282        -
 283        +How
 284         Bye
 285    "}),
 286        "{prompt}"
 287    );
 288
 289    respond_tx
 290        .send(model_response(
 291            request,
 292            indoc! {r#"
 293                --- a/root/foo.md
 294                +++ b/root/foo.md
 295                @@ ... @@
 296                 Hello!
 297                -How
 298                +How are you?
 299                 Bye
 300        "#},
 301        ))
 302        .unwrap();
 303
 304    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 305
 306    assert_eq!(prediction.edits.len(), 1);
 307    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 308}
 309
 310#[gpui::test]
 311async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 312    let (ep_store, _requests) = init_test_with_fake_client(cx);
 313    let fs = FakeFs::new(cx.executor());
 314    fs.insert_tree(
 315        "/root",
 316        json!({
 317            "foo.md": "Hello!\n\nBye\n"
 318        }),
 319    )
 320    .await;
 321    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 322
 323    let buffer = project
 324        .update(cx, |project, cx| {
 325            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 326            project.open_buffer(path, cx)
 327        })
 328        .await
 329        .unwrap();
 330
 331    ep_store.update(cx, |ep_store, cx| {
 332        ep_store.register_buffer(&buffer, &project, cx);
 333    });
 334
 335    // First burst: insert "How"
 336    buffer.update(cx, |buffer, cx| {
 337        buffer.edit(vec![(7..7, "How")], None, cx);
 338    });
 339
 340    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 341    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 342    cx.run_until_parked();
 343
 344    // Second burst: append " are you?" immediately after "How" on the same line.
 345    //
 346    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 347    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 348    buffer.update(cx, |buffer, cx| {
 349        buffer.edit(vec![(10..10, " are you?")], None, cx);
 350    });
 351
 352    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 353    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 354    buffer.update(cx, |buffer, cx| {
 355        buffer.edit(vec![(19..19, "!")], None, cx);
 356    });
 357
 358    // Without time-based splitting, there is one event.
 359    let events = ep_store.update(cx, |ep_store, cx| {
 360        ep_store.edit_history_for_project(&project, cx)
 361    });
 362    assert_eq!(events.len(), 1);
 363    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
 364    assert_eq!(
 365        diff.as_str(),
 366        indoc! {"
 367            @@ -1,3 +1,3 @@
 368             Hello!
 369            -
 370            +How are you?!
 371             Bye
 372        "}
 373    );
 374
 375    // With time-based splitting, there are two distinct events.
 376    let events = ep_store.update(cx, |ep_store, cx| {
 377        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 378    });
 379    assert_eq!(events.len(), 2);
 380    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
 381    assert_eq!(
 382        diff.as_str(),
 383        indoc! {"
 384            @@ -1,3 +1,3 @@
 385             Hello!
 386            -
 387            +How
 388             Bye
 389        "}
 390    );
 391
 392    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].as_ref();
 393    assert_eq!(
 394        diff.as_str(),
 395        indoc! {"
 396            @@ -1,3 +1,3 @@
 397             Hello!
 398            -How
 399            +How are you?!
 400             Bye
 401        "}
 402    );
 403}
 404
 405#[gpui::test]
 406async fn test_empty_prediction(cx: &mut TestAppContext) {
 407    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 408    let fs = FakeFs::new(cx.executor());
 409    fs.insert_tree(
 410        "/root",
 411        json!({
 412            "foo.md":  "Hello!\nHow\nBye\n"
 413        }),
 414    )
 415    .await;
 416    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 417
 418    let buffer = project
 419        .update(cx, |project, cx| {
 420            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 421            project.open_buffer(path, cx)
 422        })
 423        .await
 424        .unwrap();
 425    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 426    let position = snapshot.anchor_before(language::Point::new(1, 3));
 427
 428    ep_store.update(cx, |ep_store, cx| {
 429        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 430    });
 431
 432    let (request, respond_tx) = requests.predict.next().await.unwrap();
 433    let response = model_response(request, "");
 434    let id = response.id.clone();
 435    respond_tx.send(response).unwrap();
 436
 437    cx.run_until_parked();
 438
 439    ep_store.update(cx, |ep_store, cx| {
 440        assert!(
 441            ep_store
 442                .prediction_at(&buffer, None, &project, cx)
 443                .is_none()
 444        );
 445    });
 446
 447    // prediction is reported as rejected
 448    let (reject_request, _) = requests.reject.next().await.unwrap();
 449
 450    assert_eq!(
 451        &reject_request.rejections,
 452        &[EditPredictionRejection {
 453            request_id: id,
 454            reason: EditPredictionRejectReason::Empty,
 455            was_shown: false
 456        }]
 457    );
 458}
 459
 460#[gpui::test]
 461async fn test_interpolated_empty(cx: &mut TestAppContext) {
 462    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 463    let fs = FakeFs::new(cx.executor());
 464    fs.insert_tree(
 465        "/root",
 466        json!({
 467            "foo.md":  "Hello!\nHow\nBye\n"
 468        }),
 469    )
 470    .await;
 471    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 472
 473    let buffer = project
 474        .update(cx, |project, cx| {
 475            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 476            project.open_buffer(path, cx)
 477        })
 478        .await
 479        .unwrap();
 480    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 481    let position = snapshot.anchor_before(language::Point::new(1, 3));
 482
 483    ep_store.update(cx, |ep_store, cx| {
 484        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 485    });
 486
 487    let (request, respond_tx) = requests.predict.next().await.unwrap();
 488
 489    buffer.update(cx, |buffer, cx| {
 490        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 491    });
 492
 493    let response = model_response(request, SIMPLE_DIFF);
 494    let id = response.id.clone();
 495    respond_tx.send(response).unwrap();
 496
 497    cx.run_until_parked();
 498
 499    ep_store.update(cx, |ep_store, cx| {
 500        assert!(
 501            ep_store
 502                .prediction_at(&buffer, None, &project, cx)
 503                .is_none()
 504        );
 505    });
 506
 507    // prediction is reported as rejected
 508    let (reject_request, _) = requests.reject.next().await.unwrap();
 509
 510    assert_eq!(
 511        &reject_request.rejections,
 512        &[EditPredictionRejection {
 513            request_id: id,
 514            reason: EditPredictionRejectReason::InterpolatedEmpty,
 515            was_shown: false
 516        }]
 517    );
 518}
 519
 520const SIMPLE_DIFF: &str = indoc! { r"
 521    --- a/root/foo.md
 522    +++ b/root/foo.md
 523    @@ ... @@
 524     Hello!
 525    -How
 526    +How are you?
 527     Bye
 528"};
 529
 530#[gpui::test]
 531async fn test_replace_current(cx: &mut TestAppContext) {
 532    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 533    let fs = FakeFs::new(cx.executor());
 534    fs.insert_tree(
 535        "/root",
 536        json!({
 537            "foo.md":  "Hello!\nHow\nBye\n"
 538        }),
 539    )
 540    .await;
 541    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 542
 543    let buffer = project
 544        .update(cx, |project, cx| {
 545            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 546            project.open_buffer(path, cx)
 547        })
 548        .await
 549        .unwrap();
 550    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 551    let position = snapshot.anchor_before(language::Point::new(1, 3));
 552
 553    ep_store.update(cx, |ep_store, cx| {
 554        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 555    });
 556
 557    let (request, respond_tx) = requests.predict.next().await.unwrap();
 558    let first_response = model_response(request, SIMPLE_DIFF);
 559    let first_id = first_response.id.clone();
 560    respond_tx.send(first_response).unwrap();
 561
 562    cx.run_until_parked();
 563
 564    ep_store.update(cx, |ep_store, cx| {
 565        assert_eq!(
 566            ep_store
 567                .prediction_at(&buffer, None, &project, cx)
 568                .unwrap()
 569                .id
 570                .0,
 571            first_id
 572        );
 573    });
 574
 575    // a second request is triggered
 576    ep_store.update(cx, |ep_store, cx| {
 577        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 578    });
 579
 580    let (request, respond_tx) = requests.predict.next().await.unwrap();
 581    let second_response = model_response(request, SIMPLE_DIFF);
 582    let second_id = second_response.id.clone();
 583    respond_tx.send(second_response).unwrap();
 584
 585    cx.run_until_parked();
 586
 587    ep_store.update(cx, |ep_store, cx| {
 588        // second replaces first
 589        assert_eq!(
 590            ep_store
 591                .prediction_at(&buffer, None, &project, cx)
 592                .unwrap()
 593                .id
 594                .0,
 595            second_id
 596        );
 597    });
 598
 599    // first is reported as replaced
 600    let (reject_request, _) = requests.reject.next().await.unwrap();
 601
 602    assert_eq!(
 603        &reject_request.rejections,
 604        &[EditPredictionRejection {
 605            request_id: first_id,
 606            reason: EditPredictionRejectReason::Replaced,
 607            was_shown: false
 608        }]
 609    );
 610}
 611
 612#[gpui::test]
 613async fn test_current_preferred(cx: &mut TestAppContext) {
 614    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 615    let fs = FakeFs::new(cx.executor());
 616    fs.insert_tree(
 617        "/root",
 618        json!({
 619            "foo.md":  "Hello!\nHow\nBye\n"
 620        }),
 621    )
 622    .await;
 623    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 624
 625    let buffer = project
 626        .update(cx, |project, cx| {
 627            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 628            project.open_buffer(path, cx)
 629        })
 630        .await
 631        .unwrap();
 632    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 633    let position = snapshot.anchor_before(language::Point::new(1, 3));
 634
 635    ep_store.update(cx, |ep_store, cx| {
 636        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 637    });
 638
 639    let (request, respond_tx) = requests.predict.next().await.unwrap();
 640    let first_response = model_response(request, SIMPLE_DIFF);
 641    let first_id = first_response.id.clone();
 642    respond_tx.send(first_response).unwrap();
 643
 644    cx.run_until_parked();
 645
 646    ep_store.update(cx, |ep_store, cx| {
 647        assert_eq!(
 648            ep_store
 649                .prediction_at(&buffer, None, &project, cx)
 650                .unwrap()
 651                .id
 652                .0,
 653            first_id
 654        );
 655    });
 656
 657    // a second request is triggered
 658    ep_store.update(cx, |ep_store, cx| {
 659        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 660    });
 661
 662    let (request, respond_tx) = requests.predict.next().await.unwrap();
 663    // worse than current prediction
 664    let second_response = model_response(
 665        request,
 666        indoc! { r"
 667            --- a/root/foo.md
 668            +++ b/root/foo.md
 669            @@ ... @@
 670             Hello!
 671            -How
 672            +How are
 673             Bye
 674        "},
 675    );
 676    let second_id = second_response.id.clone();
 677    respond_tx.send(second_response).unwrap();
 678
 679    cx.run_until_parked();
 680
 681    ep_store.update(cx, |ep_store, cx| {
 682        // first is preferred over second
 683        assert_eq!(
 684            ep_store
 685                .prediction_at(&buffer, None, &project, cx)
 686                .unwrap()
 687                .id
 688                .0,
 689            first_id
 690        );
 691    });
 692
 693    // second is reported as rejected
 694    let (reject_request, _) = requests.reject.next().await.unwrap();
 695
 696    assert_eq!(
 697        &reject_request.rejections,
 698        &[EditPredictionRejection {
 699            request_id: second_id,
 700            reason: EditPredictionRejectReason::CurrentPreferred,
 701            was_shown: false
 702        }]
 703    );
 704}
 705
 706#[gpui::test]
 707async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 708    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 709    let fs = FakeFs::new(cx.executor());
 710    fs.insert_tree(
 711        "/root",
 712        json!({
 713            "foo.md":  "Hello!\nHow\nBye\n"
 714        }),
 715    )
 716    .await;
 717    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 718
 719    let buffer = project
 720        .update(cx, |project, cx| {
 721            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 722            project.open_buffer(path, cx)
 723        })
 724        .await
 725        .unwrap();
 726    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 727    let position = snapshot.anchor_before(language::Point::new(1, 3));
 728
 729    // start two refresh tasks
 730    ep_store.update(cx, |ep_store, cx| {
 731        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 732    });
 733
 734    let (request1, respond_first) = requests.predict.next().await.unwrap();
 735
 736    ep_store.update(cx, |ep_store, cx| {
 737        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 738    });
 739
 740    let (request, respond_second) = requests.predict.next().await.unwrap();
 741
 742    // wait for throttle
 743    cx.run_until_parked();
 744
 745    // second responds first
 746    let second_response = model_response(request, SIMPLE_DIFF);
 747    let second_id = second_response.id.clone();
 748    respond_second.send(second_response).unwrap();
 749
 750    cx.run_until_parked();
 751
 752    ep_store.update(cx, |ep_store, cx| {
 753        // current prediction is second
 754        assert_eq!(
 755            ep_store
 756                .prediction_at(&buffer, None, &project, cx)
 757                .unwrap()
 758                .id
 759                .0,
 760            second_id
 761        );
 762    });
 763
 764    let first_response = model_response(request1, SIMPLE_DIFF);
 765    let first_id = first_response.id.clone();
 766    respond_first.send(first_response).unwrap();
 767
 768    cx.run_until_parked();
 769
 770    ep_store.update(cx, |ep_store, cx| {
 771        // current prediction is still second, since first was cancelled
 772        assert_eq!(
 773            ep_store
 774                .prediction_at(&buffer, None, &project, cx)
 775                .unwrap()
 776                .id
 777                .0,
 778            second_id
 779        );
 780    });
 781
 782    // first is reported as rejected
 783    let (reject_request, _) = requests.reject.next().await.unwrap();
 784
 785    cx.run_until_parked();
 786
 787    assert_eq!(
 788        &reject_request.rejections,
 789        &[EditPredictionRejection {
 790            request_id: first_id,
 791            reason: EditPredictionRejectReason::Canceled,
 792            was_shown: false
 793        }]
 794    );
 795}
 796
 797#[gpui::test]
 798async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 799    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 800    let fs = FakeFs::new(cx.executor());
 801    fs.insert_tree(
 802        "/root",
 803        json!({
 804            "foo.md":  "Hello!\nHow\nBye\n"
 805        }),
 806    )
 807    .await;
 808    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 809
 810    let buffer = project
 811        .update(cx, |project, cx| {
 812            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 813            project.open_buffer(path, cx)
 814        })
 815        .await
 816        .unwrap();
 817    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 818    let position = snapshot.anchor_before(language::Point::new(1, 3));
 819
 820    // start two refresh tasks
 821    ep_store.update(cx, |ep_store, cx| {
 822        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 823    });
 824
 825    let (request1, respond_first) = requests.predict.next().await.unwrap();
 826
 827    ep_store.update(cx, |ep_store, cx| {
 828        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 829    });
 830
 831    let (request2, respond_second) = requests.predict.next().await.unwrap();
 832
 833    // wait for throttle, so requests are sent
 834    cx.run_until_parked();
 835
 836    ep_store.update(cx, |ep_store, cx| {
 837        // start a third request
 838        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 839
 840        // 2 are pending, so 2nd is cancelled
 841        assert_eq!(
 842            ep_store
 843                .get_or_init_project(&project, cx)
 844                .cancelled_predictions
 845                .iter()
 846                .copied()
 847                .collect::<Vec<_>>(),
 848            [1]
 849        );
 850    });
 851
 852    // wait for throttle
 853    cx.run_until_parked();
 854
 855    let (request3, respond_third) = requests.predict.next().await.unwrap();
 856
 857    let first_response = model_response(request1, SIMPLE_DIFF);
 858    let first_id = first_response.id.clone();
 859    respond_first.send(first_response).unwrap();
 860
 861    cx.run_until_parked();
 862
 863    ep_store.update(cx, |ep_store, cx| {
 864        // current prediction is first
 865        assert_eq!(
 866            ep_store
 867                .prediction_at(&buffer, None, &project, cx)
 868                .unwrap()
 869                .id
 870                .0,
 871            first_id
 872        );
 873    });
 874
 875    let cancelled_response = model_response(request2, SIMPLE_DIFF);
 876    let cancelled_id = cancelled_response.id.clone();
 877    respond_second.send(cancelled_response).unwrap();
 878
 879    cx.run_until_parked();
 880
 881    ep_store.update(cx, |ep_store, cx| {
 882        // current prediction is still first, since second was cancelled
 883        assert_eq!(
 884            ep_store
 885                .prediction_at(&buffer, None, &project, cx)
 886                .unwrap()
 887                .id
 888                .0,
 889            first_id
 890        );
 891    });
 892
 893    let third_response = model_response(request3, SIMPLE_DIFF);
 894    let third_response_id = third_response.id.clone();
 895    respond_third.send(third_response).unwrap();
 896
 897    cx.run_until_parked();
 898
 899    ep_store.update(cx, |ep_store, cx| {
 900        // third completes and replaces first
 901        assert_eq!(
 902            ep_store
 903                .prediction_at(&buffer, None, &project, cx)
 904                .unwrap()
 905                .id
 906                .0,
 907            third_response_id
 908        );
 909    });
 910
 911    // second is reported as rejected
 912    let (reject_request, _) = requests.reject.next().await.unwrap();
 913
 914    cx.run_until_parked();
 915
 916    assert_eq!(
 917        &reject_request.rejections,
 918        &[
 919            EditPredictionRejection {
 920                request_id: cancelled_id,
 921                reason: EditPredictionRejectReason::Canceled,
 922                was_shown: false
 923            },
 924            EditPredictionRejection {
 925                request_id: first_id,
 926                reason: EditPredictionRejectReason::Replaced,
 927                was_shown: false
 928            }
 929        ]
 930    );
 931}
 932
 933#[gpui::test]
 934async fn test_rejections_flushing(cx: &mut TestAppContext) {
 935    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 936
 937    ep_store.update(cx, |ep_store, _cx| {
 938        ep_store.reject_prediction(
 939            EditPredictionId("test-1".into()),
 940            EditPredictionRejectReason::Discarded,
 941            false,
 942        );
 943        ep_store.reject_prediction(
 944            EditPredictionId("test-2".into()),
 945            EditPredictionRejectReason::Canceled,
 946            true,
 947        );
 948    });
 949
 950    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
 951    cx.run_until_parked();
 952
 953    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
 954    respond_tx.send(()).unwrap();
 955
 956    // batched
 957    assert_eq!(reject_request.rejections.len(), 2);
 958    assert_eq!(
 959        reject_request.rejections[0],
 960        EditPredictionRejection {
 961            request_id: "test-1".to_string(),
 962            reason: EditPredictionRejectReason::Discarded,
 963            was_shown: false
 964        }
 965    );
 966    assert_eq!(
 967        reject_request.rejections[1],
 968        EditPredictionRejection {
 969            request_id: "test-2".to_string(),
 970            reason: EditPredictionRejectReason::Canceled,
 971            was_shown: true
 972        }
 973    );
 974
 975    // Reaching batch size limit sends without debounce
 976    ep_store.update(cx, |ep_store, _cx| {
 977        for i in 0..70 {
 978            ep_store.reject_prediction(
 979                EditPredictionId(format!("batch-{}", i).into()),
 980                EditPredictionRejectReason::Discarded,
 981                false,
 982            );
 983        }
 984    });
 985
 986    // First MAX/2 items are sent immediately
 987    cx.run_until_parked();
 988    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
 989    respond_tx.send(()).unwrap();
 990
 991    assert_eq!(reject_request.rejections.len(), 50);
 992    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
 993    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
 994
 995    // Remaining items are debounced with the next batch
 996    cx.executor().advance_clock(Duration::from_secs(15));
 997    cx.run_until_parked();
 998
 999    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1000    respond_tx.send(()).unwrap();
1001
1002    assert_eq!(reject_request.rejections.len(), 20);
1003    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1004    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1005
1006    // Request failure
1007    ep_store.update(cx, |ep_store, _cx| {
1008        ep_store.reject_prediction(
1009            EditPredictionId("retry-1".into()),
1010            EditPredictionRejectReason::Discarded,
1011            false,
1012        );
1013    });
1014
1015    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1016    cx.run_until_parked();
1017
1018    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1019    assert_eq!(reject_request.rejections.len(), 1);
1020    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1021    // Simulate failure
1022    drop(_respond_tx);
1023
1024    // Add another rejection
1025    ep_store.update(cx, |ep_store, _cx| {
1026        ep_store.reject_prediction(
1027            EditPredictionId("retry-2".into()),
1028            EditPredictionRejectReason::Discarded,
1029            false,
1030        );
1031    });
1032
1033    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1034    cx.run_until_parked();
1035
1036    // Retry should include both the failed item and the new one
1037    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1038    respond_tx.send(()).unwrap();
1039
1040    assert_eq!(reject_request.rejections.len(), 2);
1041    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1042    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1043}
1044
1045// Skipped until we start including diagnostics in prompt
1046// #[gpui::test]
1047// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1048//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1049//     let fs = FakeFs::new(cx.executor());
1050//     fs.insert_tree(
1051//         "/root",
1052//         json!({
1053//             "foo.md": "Hello!\nBye"
1054//         }),
1055//     )
1056//     .await;
1057//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1058
1059//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1060//     let diagnostic = lsp::Diagnostic {
1061//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1062//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1063//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1064//         ..Default::default()
1065//     };
1066
1067//     project.update(cx, |project, cx| {
1068//         project.lsp_store().update(cx, |lsp_store, cx| {
1069//             // Create some diagnostics
1070//             lsp_store
1071//                 .update_diagnostics(
1072//                     LanguageServerId(0),
1073//                     lsp::PublishDiagnosticsParams {
1074//                         uri: path_to_buffer_uri.clone(),
1075//                         diagnostics: vec![diagnostic],
1076//                         version: None,
1077//                     },
1078//                     None,
1079//                     language::DiagnosticSourceKind::Pushed,
1080//                     &[],
1081//                     cx,
1082//                 )
1083//                 .unwrap();
1084//         });
1085//     });
1086
1087//     let buffer = project
1088//         .update(cx, |project, cx| {
1089//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1090//             project.open_buffer(path, cx)
1091//         })
1092//         .await
1093//         .unwrap();
1094
1095//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1096//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1097
1098//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1099//         ep_store.request_prediction(&project, &buffer, position, cx)
1100//     });
1101
1102//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1103
1104//     assert_eq!(request.diagnostic_groups.len(), 1);
1105//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1106//         .unwrap();
1107//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1108//     assert_eq!(
1109//         value,
1110//         json!({
1111//             "entries": [{
1112//                 "range": {
1113//                     "start": 8,
1114//                     "end": 10
1115//                 },
1116//                 "diagnostic": {
1117//                     "source": null,
1118//                     "code": null,
1119//                     "code_description": null,
1120//                     "severity": 1,
1121//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1122//                     "markdown": null,
1123//                     "group_id": 0,
1124//                     "is_primary": true,
1125//                     "is_disk_based": false,
1126//                     "is_unnecessary": false,
1127//                     "source_kind": "Pushed",
1128//                     "data": null,
1129//                     "underline": true
1130//                 }
1131//             }],
1132//             "primary_ix": 0
1133//         })
1134//     );
1135// }
1136
1137// Generate a model response that would apply the given diff to the active file.
1138fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1139    let prompt = match &request.messages[0] {
1140        open_ai::RequestMessage::User {
1141            content: open_ai::MessageContent::Plain(content),
1142        } => content,
1143        _ => panic!("unexpected request {request:?}"),
1144    };
1145
1146    let open = "<editable_region>\n";
1147    let close = "</editable_region>";
1148    let cursor = "<|user_cursor|>";
1149
1150    let start_ix = open.len() + prompt.find(open).unwrap();
1151    let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1152    let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1153    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1154
1155    open_ai::Response {
1156        id: Uuid::new_v4().to_string(),
1157        object: "response".into(),
1158        created: 0,
1159        model: "model".into(),
1160        choices: vec![open_ai::Choice {
1161            index: 0,
1162            message: open_ai::RequestMessage::Assistant {
1163                content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1164                tool_calls: vec![],
1165            },
1166            finish_reason: None,
1167        }],
1168        usage: Usage {
1169            prompt_tokens: 0,
1170            completion_tokens: 0,
1171            total_tokens: 0,
1172        },
1173    }
1174}
1175
1176fn prompt_from_request(request: &open_ai::Request) -> &str {
1177    assert_eq!(request.messages.len(), 1);
1178    let open_ai::RequestMessage::User {
1179        content: open_ai::MessageContent::Plain(content),
1180        ..
1181    } = &request.messages[0]
1182    else {
1183        panic!(
1184            "Request does not have single user message of type Plain. {:#?}",
1185            request
1186        );
1187    };
1188    content
1189}
1190
/// Receiving ends for the requests captured by the fake HTTP client that
/// `init_test_with_fake_client` builds.
///
/// Each item pairs the deserialized request body with a oneshot sender. The test replies
/// through that sender. Dropping it without sending makes the fake request fail.
struct RequestChannels {
    /// Requests sent to `/predict_edits/raw`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    /// Requests sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1195
1196fn init_test_with_fake_client(
1197    cx: &mut TestAppContext,
1198) -> (Entity<EditPredictionStore>, RequestChannels) {
1199    cx.update(move |cx| {
1200        let settings_store = SettingsStore::test(cx);
1201        cx.set_global(settings_store);
1202        zlog::init_test();
1203
1204        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1205        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1206
1207        let http_client = FakeHttpClient::create({
1208            move |req| {
1209                let uri = req.uri().path().to_string();
1210                let mut body = req.into_body();
1211                let predict_req_tx = predict_req_tx.clone();
1212                let reject_req_tx = reject_req_tx.clone();
1213                async move {
1214                    let resp = match uri.as_str() {
1215                        "/client/llm_tokens" => serde_json::to_string(&json!({
1216                            "token": "test"
1217                        }))
1218                        .unwrap(),
1219                        "/predict_edits/raw" => {
1220                            let mut buf = Vec::new();
1221                            body.read_to_end(&mut buf).await.ok();
1222                            let req = serde_json::from_slice(&buf).unwrap();
1223
1224                            let (res_tx, res_rx) = oneshot::channel();
1225                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1226                            serde_json::to_string(&res_rx.await?).unwrap()
1227                        }
1228                        "/predict_edits/reject" => {
1229                            let mut buf = Vec::new();
1230                            body.read_to_end(&mut buf).await.ok();
1231                            let req = serde_json::from_slice(&buf).unwrap();
1232
1233                            let (res_tx, res_rx) = oneshot::channel();
1234                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1235                            serde_json::to_string(&res_rx.await?).unwrap()
1236                        }
1237                        _ => {
1238                            panic!("Unexpected path: {}", uri)
1239                        }
1240                    };
1241
1242                    Ok(Response::builder().body(resp.into()).unwrap())
1243                }
1244            }
1245        });
1246
1247        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1248        client.cloud_client().set_credentials(1, "test".into());
1249
1250        language_model::init(client.clone(), cx);
1251
1252        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1253        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1254
1255        (
1256            ep_store,
1257            RequestChannels {
1258                predict: predict_req_rx,
1259                reject: reject_req_rx,
1260            },
1261        )
1262    })
1263}
1264
/// Text of the 0BSD license. The data-collection tests use it as the `LICENSE` file of
/// worktrees whose files they expect to be collectable.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1266
/// Checks that `EditPrediction::interpolate` re-bases the predicted edits onto newer
/// buffer snapshots as the user types toward (or away from) the prediction. Once the
/// buffer diverges, it returns `None`.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits on "Lorem ipsum dolor": replace "rem" (2..5) with "REM" and
    // delete "um" (9..11).
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        // Placeholder prompt inputs; this test only exercises interpolation.
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: the edits come back as predicted.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // User deletes "rem": "REM" becomes a pure insertion and the deletion of "um"
        // shifts left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original text, and with it the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // User types "R" in place of "rem": only "EM" is left to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // User types "E": only "M" is left to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // User types "M": the insertion is complete and only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // User deletes the typed "M": its insertion is predicted again.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // User deletes "um" themselves: only the "M" insertion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // User deletes text the prediction would keep (" i"): it no longer applies.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1380
1381#[gpui::test]
1382async fn test_clean_up_diff(cx: &mut TestAppContext) {
1383    init_test(cx);
1384
1385    assert_eq!(
1386        apply_edit_prediction(
1387            indoc! {"
1388                    fn main() {
1389                        let word_1 = \"lorem\";
1390                        let range = word.len()..word.len();
1391                    }
1392                "},
1393            indoc! {"
1394                    <|editable_region_start|>
1395                    fn main() {
1396                        let word_1 = \"lorem\";
1397                        let range = word_1.len()..word_1.len();
1398                    }
1399
1400                    <|editable_region_end|>
1401                "},
1402            cx,
1403        )
1404        .await,
1405        indoc! {"
1406                fn main() {
1407                    let word_1 = \"lorem\";
1408                    let range = word_1.len()..word_1.len();
1409                }
1410            "},
1411    );
1412
1413    assert_eq!(
1414        apply_edit_prediction(
1415            indoc! {"
1416                    fn main() {
1417                        let story = \"the quick\"
1418                    }
1419                "},
1420            indoc! {"
1421                    <|editable_region_start|>
1422                    fn main() {
1423                        let story = \"the quick brown fox jumps over the lazy dog\";
1424                    }
1425
1426                    <|editable_region_end|>
1427                "},
1428            cx,
1429        )
1430        .await,
1431        indoc! {"
1432                fn main() {
1433                    let story = \"the quick brown fox jumps over the lazy dog\";
1434                }
1435            "},
1436    );
1437}
1438
1439#[gpui::test]
1440async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1441    init_test(cx);
1442
1443    let buffer_content = "lorem\n";
1444    let completion_response = indoc! {"
1445            ```animals.js
1446            <|start_of_file|>
1447            <|editable_region_start|>
1448            lorem
1449            ipsum
1450            <|editable_region_end|>
1451            ```"};
1452
1453    assert_eq!(
1454        apply_edit_prediction(buffer_content, completion_response, cx).await,
1455        "lorem\nipsum"
1456    );
1457}
1458
1459#[gpui::test]
1460async fn test_can_collect_data(cx: &mut TestAppContext) {
1461    init_test(cx);
1462
1463    let fs = project::FakeFs::new(cx.executor());
1464    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1465        .await;
1466
1467    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1468    let buffer = project
1469        .update(cx, |project, cx| {
1470            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1471        })
1472        .await
1473        .unwrap();
1474
1475    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1476    ep_store.update(cx, |ep_store, _cx| {
1477        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1478    });
1479
1480    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1481    assert_eq!(
1482        captured_request.lock().clone().unwrap().can_collect_data,
1483        true
1484    );
1485
1486    ep_store.update(cx, |ep_store, _cx| {
1487        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1488    });
1489
1490    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1491    assert_eq!(
1492        captured_request.lock().clone().unwrap().can_collect_data,
1493        false
1494    );
1495}
1496
1497#[gpui::test]
1498async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1499    init_test(cx);
1500
1501    let fs = project::FakeFs::new(cx.executor());
1502    let project = Project::test(fs.clone(), [], cx).await;
1503
1504    let buffer = cx.new(|_cx| {
1505        Buffer::remote(
1506            language::BufferId::new(1).unwrap(),
1507            ReplicaId::new(1),
1508            language::Capability::ReadWrite,
1509            "fn main() {\n    println!(\"Hello\");\n}",
1510        )
1511    });
1512
1513    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1514    ep_store.update(cx, |ep_store, _cx| {
1515        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1516    });
1517
1518    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1519    assert_eq!(
1520        captured_request.lock().clone().unwrap().can_collect_data,
1521        false
1522    );
1523}
1524
1525#[gpui::test]
1526async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1527    init_test(cx);
1528
1529    let fs = project::FakeFs::new(cx.executor());
1530    fs.insert_tree(
1531        path!("/project"),
1532        json!({
1533            "LICENSE": BSD_0_TXT,
1534            ".env": "SECRET_KEY=secret"
1535        }),
1536    )
1537    .await;
1538
1539    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1540    let buffer = project
1541        .update(cx, |project, cx| {
1542            project.open_local_buffer("/project/.env", cx)
1543        })
1544        .await
1545        .unwrap();
1546
1547    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1548    ep_store.update(cx, |ep_store, _cx| {
1549        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1550    });
1551
1552    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1553    assert_eq!(
1554        captured_request.lock().clone().unwrap().can_collect_data,
1555        false
1556    );
1557}
1558
1559#[gpui::test]
1560async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1561    init_test(cx);
1562
1563    let fs = project::FakeFs::new(cx.executor());
1564    let project = Project::test(fs.clone(), [], cx).await;
1565    let buffer = cx.new(|cx| Buffer::local("", cx));
1566
1567    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1568    ep_store.update(cx, |ep_store, _cx| {
1569        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1570    });
1571
1572    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1573    assert_eq!(
1574        captured_request.lock().clone().unwrap().can_collect_data,
1575        false
1576    );
1577}
1578
1579#[gpui::test]
1580async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1581    init_test(cx);
1582
1583    let fs = project::FakeFs::new(cx.executor());
1584    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1585        .await;
1586
1587    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1588    let buffer = project
1589        .update(cx, |project, cx| {
1590            project.open_local_buffer("/project/main.rs", cx)
1591        })
1592        .await
1593        .unwrap();
1594
1595    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1596    ep_store.update(cx, |ep_store, _cx| {
1597        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1598    });
1599
1600    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1601    assert_eq!(
1602        captured_request.lock().clone().unwrap().can_collect_data,
1603        false
1604    );
1605}
1606
1607#[gpui::test]
1608async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1609    init_test(cx);
1610
1611    let fs = project::FakeFs::new(cx.executor());
1612    fs.insert_tree(
1613        path!("/open_source_worktree"),
1614        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1615    )
1616    .await;
1617    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1618        .await;
1619
1620    let project = Project::test(
1621        fs.clone(),
1622        [
1623            path!("/open_source_worktree").as_ref(),
1624            path!("/closed_source_worktree").as_ref(),
1625        ],
1626        cx,
1627    )
1628    .await;
1629    let buffer = project
1630        .update(cx, |project, cx| {
1631            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1632        })
1633        .await
1634        .unwrap();
1635
1636    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1637    ep_store.update(cx, |ep_store, _cx| {
1638        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1639    });
1640
1641    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1642    assert_eq!(
1643        captured_request.lock().clone().unwrap().can_collect_data,
1644        true
1645    );
1646
1647    let closed_source_file = project
1648        .update(cx, |project, cx| {
1649            let worktree2 = project
1650                .worktree_for_root_name("closed_source_worktree", cx)
1651                .unwrap();
1652            worktree2.update(cx, |worktree2, cx| {
1653                worktree2.load_file(rel_path("main.rs"), cx)
1654            })
1655        })
1656        .await
1657        .unwrap()
1658        .file;
1659
1660    buffer.update(cx, |buffer, cx| {
1661        buffer.file_updated(closed_source_file, cx);
1662    });
1663
1664    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1665    assert_eq!(
1666        captured_request.lock().clone().unwrap().can_collect_data,
1667        false
1668    );
1669}
1670
1671#[gpui::test]
1672async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1673    init_test(cx);
1674
1675    let fs = project::FakeFs::new(cx.executor());
1676    fs.insert_tree(
1677        path!("/worktree1"),
1678        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1679    )
1680    .await;
1681    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1682        .await;
1683
1684    let project = Project::test(
1685        fs.clone(),
1686        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1687        cx,
1688    )
1689    .await;
1690    let buffer = project
1691        .update(cx, |project, cx| {
1692            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1693        })
1694        .await
1695        .unwrap();
1696    let private_buffer = project
1697        .update(cx, |project, cx| {
1698            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1699        })
1700        .await
1701        .unwrap();
1702
1703    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1704    ep_store.update(cx, |ep_store, _cx| {
1705        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1706    });
1707
1708    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1709    assert_eq!(
1710        captured_request.lock().clone().unwrap().can_collect_data,
1711        true
1712    );
1713
1714    // this has a side effect of registering the buffer to watch for edits
1715    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1716    assert_eq!(
1717        captured_request.lock().clone().unwrap().can_collect_data,
1718        false
1719    );
1720
1721    private_buffer.update(cx, |private_buffer, cx| {
1722        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1723    });
1724
1725    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1726    assert_eq!(
1727        captured_request.lock().clone().unwrap().can_collect_data,
1728        false
1729    );
1730
1731    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1732    // included
1733    buffer.update(cx, |buffer, cx| {
1734        buffer.edit(
1735            [(
1736                0..0,
1737                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1738            )],
1739            None,
1740            cx,
1741        );
1742    });
1743
1744    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1745    assert_eq!(
1746        captured_request.lock().clone().unwrap().can_collect_data,
1747        true
1748    );
1749}
1750
1751fn init_test(cx: &mut TestAppContext) {
1752    cx.update(|cx| {
1753        let settings_store = SettingsStore::test(cx);
1754        cx.set_global(settings_store);
1755    });
1756}
1757
1758async fn apply_edit_prediction(
1759    buffer_content: &str,
1760    completion_response: &str,
1761    cx: &mut TestAppContext,
1762) -> String {
1763    let fs = project::FakeFs::new(cx.executor());
1764    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1765    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1766    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1767    *response.lock() = completion_response.to_string();
1768    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1769    buffer.update(cx, |buffer, cx| {
1770        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1771    });
1772    buffer.read_with(cx, |buffer, _| buffer.text())
1773}
1774
1775async fn run_edit_prediction(
1776    buffer: &Entity<Buffer>,
1777    project: &Entity<Project>,
1778    ep_store: &Entity<EditPredictionStore>,
1779    cx: &mut TestAppContext,
1780) -> EditPrediction {
1781    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1782    ep_store.update(cx, |ep_store, cx| {
1783        ep_store.register_buffer(buffer, &project, cx)
1784    });
1785    cx.background_executor.run_until_parked();
1786    let prediction_task = ep_store.update(cx, |ep_store, cx| {
1787        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1788    });
1789    prediction_task.await.unwrap().unwrap().prediction.unwrap()
1790}
1791
1792async fn make_test_ep_store(
1793    project: &Entity<Project>,
1794    cx: &mut TestAppContext,
1795) -> (
1796    Entity<EditPredictionStore>,
1797    Arc<Mutex<Option<PredictEditsBody>>>,
1798    Arc<Mutex<String>>,
1799) {
1800    let default_response = indoc! {"
1801            ```main.rs
1802            <|start_of_file|>
1803            <|editable_region_start|>
1804            hello world
1805            <|editable_region_end|>
1806            ```"
1807    };
1808    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1809    let completion_response: Arc<Mutex<String>> =
1810        Arc::new(Mutex::new(default_response.to_string()));
1811    let http_client = FakeHttpClient::create({
1812        let captured_request = captured_request.clone();
1813        let completion_response = completion_response.clone();
1814        let mut next_request_id = 0;
1815        move |req| {
1816            let captured_request = captured_request.clone();
1817            let completion_response = completion_response.clone();
1818            async move {
1819                match (req.method(), req.uri().path()) {
1820                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1821                        .status(200)
1822                        .body(
1823                            serde_json::to_string(&CreateLlmTokenResponse {
1824                                token: LlmToken("the-llm-token".to_string()),
1825                            })
1826                            .unwrap()
1827                            .into(),
1828                        )
1829                        .unwrap()),
1830                    (&Method::POST, "/predict_edits/v2") => {
1831                        let mut request_body = String::new();
1832                        req.into_body().read_to_string(&mut request_body).await?;
1833                        *captured_request.lock() =
1834                            Some(serde_json::from_str(&request_body).unwrap());
1835                        next_request_id += 1;
1836                        Ok(http_client::Response::builder()
1837                            .status(200)
1838                            .body(
1839                                serde_json::to_string(&PredictEditsResponse {
1840                                    request_id: format!("request-{next_request_id}"),
1841                                    output_excerpt: completion_response.lock().clone(),
1842                                })
1843                                .unwrap()
1844                                .into(),
1845                            )
1846                            .unwrap())
1847                    }
1848                    _ => Ok(http_client::Response::builder()
1849                        .status(404)
1850                        .body("Not Found".into())
1851                        .unwrap()),
1852                }
1853            }
1854        }
1855    });
1856
1857    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
1858    cx.update(|cx| {
1859        RefreshLlmTokenListener::register(client.clone(), cx);
1860    });
1861    let _server = FakeServer::for_client(42, &client, cx).await;
1862
1863    let ep_store = cx.new(|cx| {
1864        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
1865        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
1866
1867        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
1868        for worktree in worktrees {
1869            let worktree_id = worktree.read(cx).id();
1870            ep_store
1871                .get_or_init_project(project, cx)
1872                .license_detection_watchers
1873                .entry(worktree_id)
1874                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
1875        }
1876
1877        ep_store
1878    });
1879
1880    (ep_store, captured_request, completion_response)
1881}
1882
1883fn to_completion_edits(
1884    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
1885    buffer: &Entity<Buffer>,
1886    cx: &App,
1887) -> Vec<(Range<Anchor>, Arc<str>)> {
1888    let buffer = buffer.read(cx);
1889    iterator
1890        .into_iter()
1891        .map(|(range, text)| {
1892            (
1893                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1894                text,
1895            )
1896        })
1897        .collect()
1898}
1899
1900fn from_completion_edits(
1901    editor_edits: &[(Range<Anchor>, Arc<str>)],
1902    buffer: &Entity<Buffer>,
1903    cx: &App,
1904) -> Vec<(Range<usize>, Arc<str>)> {
1905    let buffer = buffer.read(cx);
1906    editor_edits
1907        .iter()
1908        .map(|(range, text)| {
1909            (
1910                range.start.to_offset(buffer)..range.end.to_offset(buffer),
1911                text.clone(),
1912            )
1913        })
1914        .collect()
1915}
1916
/// Initializes test logging via `zlog::init_test`.
///
/// `#[ctor::ctor]` runs this function when the test binary is loaded, before any test
/// starts. Individual tests therefore do not need to set up logging themselves.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}