edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
   3use client::{UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
   8    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
   9};
  10use futures::{
  11    AsyncReadExt, FutureExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::App;
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::path;
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
// Checks how `prediction_at` classifies the current prediction relative to the
// buffer being queried:
//
// 1. A prediction for the active buffer (`1.txt`) is reported as
//    `BufferEditPrediction::Local`.
// 2. After that prediction is rejected, publishing a diagnostic for `2.txt` is
//    expected to produce a new prediction request. Its edit targets `2.txt`, so
//    querying from `1.txt` reports a `Jump` whose snapshot belongs to `2.txt`,
//    while querying from the `2.txt` buffer itself reports it as `Local`.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open `1.txt` and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" (row 1, column 3).
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    // Let the store process the response before querying it.
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Reject the current prediction so the assertions below observe the one
    // produced in response to the diagnostic.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publish the diagnostic for `2.txt` (which has no open buffer yet).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // The diagnostic is expected to trigger a new prediction request.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from `1.txt`, the prediction lives in another file: a jump to `2.txt`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from `2.txt` itself, the same prediction is local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 169
// Requests a prediction directly through `request_prediction` and checks that a
// model diff rewriting "How" to "How are you?" becomes exactly one edit that
// starts at the cursor (row 1, column 3) with new text " are you?".
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO: once the request is structured again, assert that it names
    // `root/foo.md` as the excerpt path and places the cursor at line 1,
    // column 3.

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The line rewrite is reduced to a single edit starting at the cursor.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 236
// Checks that a manual edit made after the buffer is registered appears as an
// edit-history diff in the prompt sent to the model, and that the model's reply
// still yields the expected single-edit prediction.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register first so the edit below is captured in the edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Offset 7 is the start of the empty second line (right after "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The edit made above must show up in the prompt as a unified diff.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 309
// Checks that `edit_history_for_project` splits the most recent event at a
// typing pause. An edit made before the pause and the edits made after it are
// reported as two separate diffs, even though all of them touch the same line.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How" at offset 7, the start of the empty second line.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold by advancing the fake
    // clock by twice `LAST_CHANGE_GROUPING_TIME`.
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line
    // (offset 10 = 7 + "How".len()).
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    // Offset 19 = 10 + " are you?".len(), i.e. the end of "How are you?".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Before the pause: the empty line became "How".
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // After the pause: both post-pause edits, diffed against the pre-pause text.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 387
// Checks line-proximity coalescing of edits in the edit history. Edits close to
// the existing event's row range are merged into it, whether they land above or
// below that range (the comments below assume an 8-line window). An edit farther
// away starts a new event.
//
// NOTE(review): despite its name, this test performs only manual edits and
// never marks any change as predicted. Predicted/manual separation is exercised
// by `test_predicted_flag_coalescing`. Consider renaming this test to reflect
// what it covers.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range.
    // After the "Above" insertion, row 12 is "Middle B" and row 14 is "Line 15",
    // so "Below" lands between "Line 14" and "Line 15".
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event.
    // Row 25 currently holds "Line 25", so "Far below" lands just before it.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 566
 567fn render_events(events: &[StoredEvent]) -> String {
 568    events
 569        .iter()
 570        .map(|e| {
 571            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 572            diff.as_str()
 573        })
 574        .collect::<Vec<_>>()
 575        .join("\n---\n")
 576}
 577
 578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 579    events
 580        .iter()
 581        .map(|e| {
 582            let zeta_prompt::Event::BufferChange {
 583                diff, predicted, ..
 584            } = e.event.as_ref();
 585            let prefix = if *predicted { "predicted" } else { "manual" };
 586            format!("{}\n{}", prefix, diff)
 587        })
 588        .collect()
 589}
 590
// Checks how the `predicted` flag interacts with edit-history coalescing:
// - nearby edits of the same kind (manual or predicted) merge into one event;
// - switching between manual and predicted edits at the same location starts a
//   new event;
// - once editing moves to a distant location, the earlier events at the old
//   location are merged, and the merged event is `predicted` only if every
//   constituent edit was predicted.
// Edits are marked as predicted by calling `report_changes_for_buffer` with
// `true` inside the same `ep_store.update` as the edit itself.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    // (0..6 is exactly "line 0".)
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true. The edit is
    // reported with `report_changes_for_buffer(.., true, ..)` in the same
    // `ep_store.update` as the edit itself. The resulting event is marked
    // predicted and is kept separate from the nearby manual event above.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5: A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    // (Row 14 columns 0..7 is exactly "line 14".)
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
 849
 850#[gpui::test]
 851async fn test_empty_prediction(cx: &mut TestAppContext) {
 852    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 853    let fs = FakeFs::new(cx.executor());
 854    fs.insert_tree(
 855        "/root",
 856        json!({
 857            "foo.md":  "Hello!\nHow\nBye\n"
 858        }),
 859    )
 860    .await;
 861    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 862
 863    let buffer = project
 864        .update(cx, |project, cx| {
 865            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 866            project.open_buffer(path, cx)
 867        })
 868        .await
 869        .unwrap();
 870    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 871    let position = snapshot.anchor_before(language::Point::new(1, 3));
 872
 873    ep_store.update(cx, |ep_store, cx| {
 874        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 875    });
 876
 877    let (request, respond_tx) = requests.predict.next().await.unwrap();
 878    let response = model_response(&request, "");
 879    let id = response.request_id.clone();
 880    respond_tx.send(response).unwrap();
 881
 882    cx.run_until_parked();
 883
 884    ep_store.update(cx, |ep_store, cx| {
 885        assert!(
 886            ep_store
 887                .prediction_at(&buffer, None, &project, cx)
 888                .is_none()
 889        );
 890    });
 891
 892    // prediction is reported as rejected
 893    let (reject_request, _) = requests.reject.next().await.unwrap();
 894
 895    assert_eq!(
 896        &reject_request.rejections,
 897        &[EditPredictionRejection {
 898            request_id: id,
 899            reason: EditPredictionRejectReason::Empty,
 900            was_shown: false,
 901            model_version: None,
 902        }]
 903    );
 904}
 905
 906#[gpui::test]
 907async fn test_interpolated_empty(cx: &mut TestAppContext) {
 908    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 909    let fs = FakeFs::new(cx.executor());
 910    fs.insert_tree(
 911        "/root",
 912        json!({
 913            "foo.md":  "Hello!\nHow\nBye\n"
 914        }),
 915    )
 916    .await;
 917    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 918
 919    let buffer = project
 920        .update(cx, |project, cx| {
 921            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 922            project.open_buffer(path, cx)
 923        })
 924        .await
 925        .unwrap();
 926    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 927    let position = snapshot.anchor_before(language::Point::new(1, 3));
 928
 929    ep_store.update(cx, |ep_store, cx| {
 930        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 931    });
 932
 933    let (request, respond_tx) = requests.predict.next().await.unwrap();
 934
 935    buffer.update(cx, |buffer, cx| {
 936        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 937    });
 938
 939    let response = model_response(&request, SIMPLE_DIFF);
 940    let id = response.request_id.clone();
 941    respond_tx.send(response).unwrap();
 942
 943    cx.run_until_parked();
 944
 945    ep_store.update(cx, |ep_store, cx| {
 946        assert!(
 947            ep_store
 948                .prediction_at(&buffer, None, &project, cx)
 949                .is_none()
 950        );
 951    });
 952
 953    // prediction is reported as rejected
 954    let (reject_request, _) = requests.reject.next().await.unwrap();
 955
 956    assert_eq!(
 957        &reject_request.rejections,
 958        &[EditPredictionRejection {
 959            request_id: id,
 960            reason: EditPredictionRejectReason::InterpolatedEmpty,
 961            was_shown: false,
 962            model_version: None,
 963        }]
 964    );
 965}
 966
// Unified diff that rewrites the second line of `root/foo.md` from "How" to
// "How are you?". The tests in this file pair it with the fixture
// "Hello!\nHow\nBye\n" and pass it to `model_response` as the model's suggestion.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 976
 977#[gpui::test]
 978async fn test_replace_current(cx: &mut TestAppContext) {
 979    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 980    let fs = FakeFs::new(cx.executor());
 981    fs.insert_tree(
 982        "/root",
 983        json!({
 984            "foo.md":  "Hello!\nHow\nBye\n"
 985        }),
 986    )
 987    .await;
 988    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 989
 990    let buffer = project
 991        .update(cx, |project, cx| {
 992            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 993            project.open_buffer(path, cx)
 994        })
 995        .await
 996        .unwrap();
 997    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 998    let position = snapshot.anchor_before(language::Point::new(1, 3));
 999
1000    ep_store.update(cx, |ep_store, cx| {
1001        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1002    });
1003
1004    let (request, respond_tx) = requests.predict.next().await.unwrap();
1005    let first_response = model_response(&request, SIMPLE_DIFF);
1006    let first_id = first_response.request_id.clone();
1007    respond_tx.send(first_response).unwrap();
1008
1009    cx.run_until_parked();
1010
1011    ep_store.update(cx, |ep_store, cx| {
1012        assert_eq!(
1013            ep_store
1014                .prediction_at(&buffer, None, &project, cx)
1015                .unwrap()
1016                .id
1017                .0,
1018            first_id
1019        );
1020    });
1021
1022    // a second request is triggered
1023    ep_store.update(cx, |ep_store, cx| {
1024        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1025    });
1026
1027    let (request, respond_tx) = requests.predict.next().await.unwrap();
1028    let second_response = model_response(&request, SIMPLE_DIFF);
1029    let second_id = second_response.request_id.clone();
1030    respond_tx.send(second_response).unwrap();
1031
1032    cx.run_until_parked();
1033
1034    ep_store.update(cx, |ep_store, cx| {
1035        // second replaces first
1036        assert_eq!(
1037            ep_store
1038                .prediction_at(&buffer, None, &project, cx)
1039                .unwrap()
1040                .id
1041                .0,
1042            second_id
1043        );
1044    });
1045
1046    // first is reported as replaced
1047    let (reject_request, _) = requests.reject.next().await.unwrap();
1048
1049    assert_eq!(
1050        &reject_request.rejections,
1051        &[EditPredictionRejection {
1052            request_id: first_id,
1053            reason: EditPredictionRejectReason::Replaced,
1054            was_shown: false,
1055            model_version: None,
1056        }]
1057    );
1058}
1059
1060#[gpui::test]
1061async fn test_current_preferred(cx: &mut TestAppContext) {
1062    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1063    let fs = FakeFs::new(cx.executor());
1064    fs.insert_tree(
1065        "/root",
1066        json!({
1067            "foo.md":  "Hello!\nHow\nBye\n"
1068        }),
1069    )
1070    .await;
1071    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1072
1073    let buffer = project
1074        .update(cx, |project, cx| {
1075            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1076            project.open_buffer(path, cx)
1077        })
1078        .await
1079        .unwrap();
1080    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1081    let position = snapshot.anchor_before(language::Point::new(1, 3));
1082
1083    ep_store.update(cx, |ep_store, cx| {
1084        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1085    });
1086
1087    let (request, respond_tx) = requests.predict.next().await.unwrap();
1088    let first_response = model_response(&request, SIMPLE_DIFF);
1089    let first_id = first_response.request_id.clone();
1090    respond_tx.send(first_response).unwrap();
1091
1092    cx.run_until_parked();
1093
1094    ep_store.update(cx, |ep_store, cx| {
1095        assert_eq!(
1096            ep_store
1097                .prediction_at(&buffer, None, &project, cx)
1098                .unwrap()
1099                .id
1100                .0,
1101            first_id
1102        );
1103    });
1104
1105    // a second request is triggered
1106    ep_store.update(cx, |ep_store, cx| {
1107        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1108    });
1109
1110    let (request, respond_tx) = requests.predict.next().await.unwrap();
1111    // worse than current prediction
1112    let second_response = model_response(
1113        &request,
1114        indoc! { r"
1115            --- a/root/foo.md
1116            +++ b/root/foo.md
1117            @@ ... @@
1118             Hello!
1119            -How
1120            +How are
1121             Bye
1122        "},
1123    );
1124    let second_id = second_response.request_id.clone();
1125    respond_tx.send(second_response).unwrap();
1126
1127    cx.run_until_parked();
1128
1129    ep_store.update(cx, |ep_store, cx| {
1130        // first is preferred over second
1131        assert_eq!(
1132            ep_store
1133                .prediction_at(&buffer, None, &project, cx)
1134                .unwrap()
1135                .id
1136                .0,
1137            first_id
1138        );
1139    });
1140
1141    // second is reported as rejected
1142    let (reject_request, _) = requests.reject.next().await.unwrap();
1143
1144    assert_eq!(
1145        &reject_request.rejections,
1146        &[EditPredictionRejection {
1147            request_id: second_id,
1148            reason: EditPredictionRejectReason::CurrentPreferred,
1149            was_shown: false,
1150            model_version: None,
1151        }]
1152    );
1153}
1154
1155#[gpui::test]
1156async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1157    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1158    let fs = FakeFs::new(cx.executor());
1159    fs.insert_tree(
1160        "/root",
1161        json!({
1162            "foo.md":  "Hello!\nHow\nBye\n"
1163        }),
1164    )
1165    .await;
1166    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1167
1168    let buffer = project
1169        .update(cx, |project, cx| {
1170            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1171            project.open_buffer(path, cx)
1172        })
1173        .await
1174        .unwrap();
1175    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1176    let position = snapshot.anchor_before(language::Point::new(1, 3));
1177
1178    // start two refresh tasks
1179    ep_store.update(cx, |ep_store, cx| {
1180        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1181    });
1182
1183    let (request1, respond_first) = requests.predict.next().await.unwrap();
1184
1185    ep_store.update(cx, |ep_store, cx| {
1186        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1187    });
1188
1189    let (request, respond_second) = requests.predict.next().await.unwrap();
1190
1191    // wait for throttle
1192    cx.run_until_parked();
1193
1194    // second responds first
1195    let second_response = model_response(&request, SIMPLE_DIFF);
1196    let second_id = second_response.request_id.clone();
1197    respond_second.send(second_response).unwrap();
1198
1199    cx.run_until_parked();
1200
1201    ep_store.update(cx, |ep_store, cx| {
1202        // current prediction is second
1203        assert_eq!(
1204            ep_store
1205                .prediction_at(&buffer, None, &project, cx)
1206                .unwrap()
1207                .id
1208                .0,
1209            second_id
1210        );
1211    });
1212
1213    let first_response = model_response(&request1, SIMPLE_DIFF);
1214    let first_id = first_response.request_id.clone();
1215    respond_first.send(first_response).unwrap();
1216
1217    cx.run_until_parked();
1218
1219    ep_store.update(cx, |ep_store, cx| {
1220        // current prediction is still second, since first was cancelled
1221        assert_eq!(
1222            ep_store
1223                .prediction_at(&buffer, None, &project, cx)
1224                .unwrap()
1225                .id
1226                .0,
1227            second_id
1228        );
1229    });
1230
1231    // first is reported as rejected
1232    let (reject_request, _) = requests.reject.next().await.unwrap();
1233
1234    cx.run_until_parked();
1235
1236    assert_eq!(
1237        &reject_request.rejections,
1238        &[EditPredictionRejection {
1239            request_id: first_id,
1240            reason: EditPredictionRejectReason::Canceled,
1241            was_shown: false,
1242            model_version: None,
1243        }]
1244    );
1245}
1246
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    // Starting a third refresh while two requests are pending cancels the second
    // one. Its late response is ignored and reported as `Canceled`. The first is
    // still accepted, and is reported as `Replaced` once the third response
    // arrives. Both rejections are expected in a single reject request.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled.
        // `cancelled_predictions` is expected to hold exactly `1`, presumably the
        // 0-based index of the second request. TODO confirm.
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle before the third request is sent
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both rejections arrive in one batch, in this order: the cancelled second
    // prediction (`Canceled`), then the first one displaced by the third
    // (`Replaced`).
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1384
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    // Edit predictions (buffer refreshes) and jump predictions (diagnostic
    // refreshes) each have their own throttle. A recent request of one kind
    // must not delay the first request of the other kind. A repeat of either
    // kind is held back until `EditPredictionStore::THROTTLE_TIMEOUT` elapses.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Register the project and buffer with the store. This is presumably what
    // lets the diagnostics update below trigger a jump request. TODO confirm.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request, triggered by publishing diagnostics for bar.md (not the
    // open foo.md buffer). There is no prior jump, so it is not throttled, even
    // though an edit request was just made.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through. Their relative order is not asserted,
    // so the two requests are not inspected.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1485
1486#[gpui::test]
1487async fn test_rejections_flushing(cx: &mut TestAppContext) {
1488    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1489
1490    ep_store.update(cx, |ep_store, cx| {
1491        ep_store.reject_prediction(
1492            EditPredictionId("test-1".into()),
1493            EditPredictionRejectReason::Discarded,
1494            false,
1495            None,
1496            cx,
1497        );
1498        ep_store.reject_prediction(
1499            EditPredictionId("test-2".into()),
1500            EditPredictionRejectReason::Canceled,
1501            true,
1502            None,
1503            cx,
1504        );
1505    });
1506
1507    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1508    cx.run_until_parked();
1509
1510    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1511    respond_tx.send(()).unwrap();
1512
1513    // batched
1514    assert_eq!(reject_request.rejections.len(), 2);
1515    assert_eq!(
1516        reject_request.rejections[0],
1517        EditPredictionRejection {
1518            request_id: "test-1".to_string(),
1519            reason: EditPredictionRejectReason::Discarded,
1520            was_shown: false,
1521            model_version: None,
1522        }
1523    );
1524    assert_eq!(
1525        reject_request.rejections[1],
1526        EditPredictionRejection {
1527            request_id: "test-2".to_string(),
1528            reason: EditPredictionRejectReason::Canceled,
1529            was_shown: true,
1530            model_version: None,
1531        }
1532    );
1533
1534    // Reaching batch size limit sends without debounce
1535    ep_store.update(cx, |ep_store, cx| {
1536        for i in 0..70 {
1537            ep_store.reject_prediction(
1538                EditPredictionId(format!("batch-{}", i).into()),
1539                EditPredictionRejectReason::Discarded,
1540                false,
1541                None,
1542                cx,
1543            );
1544        }
1545    });
1546
1547    // First MAX/2 items are sent immediately
1548    cx.run_until_parked();
1549    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1550    respond_tx.send(()).unwrap();
1551
1552    assert_eq!(reject_request.rejections.len(), 50);
1553    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1554    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1555
1556    // Remaining items are debounced with the next batch
1557    cx.executor().advance_clock(Duration::from_secs(15));
1558    cx.run_until_parked();
1559
1560    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1561    respond_tx.send(()).unwrap();
1562
1563    assert_eq!(reject_request.rejections.len(), 20);
1564    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1565    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1566
1567    // Request failure
1568    ep_store.update(cx, |ep_store, cx| {
1569        ep_store.reject_prediction(
1570            EditPredictionId("retry-1".into()),
1571            EditPredictionRejectReason::Discarded,
1572            false,
1573            None,
1574            cx,
1575        );
1576    });
1577
1578    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1579    cx.run_until_parked();
1580
1581    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1582    assert_eq!(reject_request.rejections.len(), 1);
1583    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1584    // Simulate failure
1585    drop(_respond_tx);
1586
1587    // Add another rejection
1588    ep_store.update(cx, |ep_store, cx| {
1589        ep_store.reject_prediction(
1590            EditPredictionId("retry-2".into()),
1591            EditPredictionRejectReason::Discarded,
1592            false,
1593            None,
1594            cx,
1595        );
1596    });
1597
1598    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1599    cx.run_until_parked();
1600
1601    // Retry should include both the failed item and the new one
1602    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1603    respond_tx.send(()).unwrap();
1604
1605    assert_eq!(reject_request.rejections.len(), 2);
1606    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1607    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1608}
1609
1610// Skipped until we start including diagnostics in prompt
1611// #[gpui::test]
1612// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1613//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1614//     let fs = FakeFs::new(cx.executor());
1615//     fs.insert_tree(
1616//         "/root",
1617//         json!({
1618//             "foo.md": "Hello!\nBye"
1619//         }),
1620//     )
1621//     .await;
1622//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1623
1624//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1625//     let diagnostic = lsp::Diagnostic {
1626//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1627//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1628//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1629//         ..Default::default()
1630//     };
1631
1632//     project.update(cx, |project, cx| {
1633//         project.lsp_store().update(cx, |lsp_store, cx| {
1634//             // Create some diagnostics
1635//             lsp_store
1636//                 .update_diagnostics(
1637//                     LanguageServerId(0),
1638//                     lsp::PublishDiagnosticsParams {
1639//                         uri: path_to_buffer_uri.clone(),
1640//                         diagnostics: vec![diagnostic],
1641//                         version: None,
1642//                     },
1643//                     None,
1644//                     language::DiagnosticSourceKind::Pushed,
1645//                     &[],
1646//                     cx,
1647//                 )
1648//                 .unwrap();
1649//         });
1650//     });
1651
1652//     let buffer = project
1653//         .update(cx, |project, cx| {
1654//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1655//             project.open_buffer(path, cx)
1656//         })
1657//         .await
1658//         .unwrap();
1659
1660//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1661//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1662
1663//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1664//         ep_store.request_prediction(&project, &buffer, position, cx)
1665//     });
1666
1667//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1668
1669//     assert_eq!(request.diagnostic_groups.len(), 1);
1670//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1671//         .unwrap();
1672//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1673//     assert_eq!(
1674//         value,
1675//         json!({
1676//             "entries": [{
1677//                 "range": {
1678//                     "start": 8,
1679//                     "end": 10
1680//                 },
1681//                 "diagnostic": {
1682//                     "source": null,
1683//                     "code": null,
1684//                     "code_description": null,
1685//                     "severity": 1,
1686//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1687//                     "markdown": null,
1688//                     "group_id": 0,
1689//                     "is_primary": true,
1690//                     "is_disk_based": false,
1691//                     "is_unnecessary": false,
1692//                     "source_kind": "Pushed",
1693//                     "data": null,
1694//                     "underline": true
1695//                 }
1696//             }],
1697//             "primary_ix": 0
1698//         })
1699//     );
1700// }
1701
1702// Generate a model response that would apply the given diff to the active file.
1703fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1704    let editable_range = request
1705        .input
1706        .excerpt_ranges
1707        .as_ref()
1708        .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1)
1709        .unwrap_or(request.input.editable_range_in_excerpt.clone());
1710    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1711    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1712
1713    PredictEditsV3Response {
1714        request_id: Uuid::new_v4().to_string(),
1715        editable_range,
1716        output: new_excerpt,
1717        model_version: None,
1718    }
1719}
1720
1721fn empty_response() -> PredictEditsV3Response {
1722    PredictEditsV3Response {
1723        request_id: Uuid::new_v4().to_string(),
1724        editable_range: 0..0,
1725        output: String::new(),
1726        model_version: None,
1727    }
1728}
1729
1730fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1731    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1732}
1733
1734fn assert_no_predict_request_ready(
1735    requests: &mut mpsc::UnboundedReceiver<(
1736        PredictEditsV3Request,
1737        oneshot::Sender<PredictEditsV3Response>,
1738    )>,
1739) {
1740    if requests.next().now_or_never().flatten().is_some() {
1741        panic!("Unexpected prediction request while throttled.");
1742    }
1743}
1744
// Receivers for the requests intercepted by the fake HTTP client built in
// `init_test_with_fake_client`. Each request arrives paired with a oneshot sender
// the test uses to supply the response. Dropping that sender instead makes the
// HTTP request fail.
struct RequestChannels {
    // Decoded bodies of `/predict_edits/v3` requests.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Decoded bodies of `/predict_edits/reject` requests.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1752
1753fn init_test_with_fake_client(
1754    cx: &mut TestAppContext,
1755) -> (Entity<EditPredictionStore>, RequestChannels) {
1756    cx.update(move |cx| {
1757        let settings_store = SettingsStore::test(cx);
1758        cx.set_global(settings_store);
1759        zlog::init_test();
1760
1761        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1762        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1763
1764        let http_client = FakeHttpClient::create({
1765            move |req| {
1766                let uri = req.uri().path().to_string();
1767                let mut body = req.into_body();
1768                let predict_req_tx = predict_req_tx.clone();
1769                let reject_req_tx = reject_req_tx.clone();
1770                async move {
1771                    let resp = match uri.as_str() {
1772                        "/client/llm_tokens" => serde_json::to_string(&json!({
1773                            "token": "test"
1774                        }))
1775                        .unwrap(),
1776                        "/predict_edits/v3" => {
1777                            let mut buf = Vec::new();
1778                            body.read_to_end(&mut buf).await.ok();
1779                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
1780                            let req = serde_json::from_slice(&decompressed).unwrap();
1781
1782                            let (res_tx, res_rx) = oneshot::channel();
1783                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1784                            serde_json::to_string(&res_rx.await?).unwrap()
1785                        }
1786                        "/predict_edits/reject" => {
1787                            let mut buf = Vec::new();
1788                            body.read_to_end(&mut buf).await.ok();
1789                            let req = serde_json::from_slice(&buf).unwrap();
1790
1791                            let (res_tx, res_rx) = oneshot::channel();
1792                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1793                            serde_json::to_string(&res_rx.await?).unwrap()
1794                        }
1795                        _ => {
1796                            panic!("Unexpected path: {}", uri)
1797                        }
1798                    };
1799
1800                    Ok(Response::builder().body(resp.into()).unwrap())
1801                }
1802            }
1803        });
1804
1805        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1806        client.cloud_client().set_credentials(1, "test".into());
1807
1808        language_model::init(client.clone(), cx);
1809
1810        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1811        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1812
1813        (
1814            ep_store,
1815            RequestChannels {
1816                predict: predict_req_rx,
1817                reject: reject_req_rx,
1818            },
1819        )
1820    })
1821}
1822
1823#[gpui::test]
1824async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1825    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1826    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1827        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1828    });
1829
1830    let edit_preview = cx
1831        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1832        .await;
1833
1834    let prediction = EditPrediction {
1835        edits,
1836        cursor_position: None,
1837        edit_preview,
1838        buffer: buffer.clone(),
1839        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1840        id: EditPredictionId("the-id".into()),
1841        inputs: ZetaPromptInput {
1842            events: Default::default(),
1843            related_files: Default::default(),
1844            cursor_path: Path::new("").into(),
1845            cursor_excerpt: "".into(),
1846            editable_range_in_excerpt: 0..0,
1847            cursor_offset_in_excerpt: 0,
1848            excerpt_start_row: None,
1849            excerpt_ranges: None,
1850            preferred_model: None,
1851            in_open_source_repo: false,
1852            can_collect_data: false,
1853        },
1854        buffer_snapshotted_at: Instant::now(),
1855        response_received_at: Instant::now(),
1856        model_version: None,
1857    };
1858
1859    cx.update(|cx| {
1860        assert_eq!(
1861            from_completion_edits(
1862                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1863                &buffer,
1864                cx
1865            ),
1866            vec![(2..5, "REM".into()), (9..11, "".into())]
1867        );
1868
1869        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1870        assert_eq!(
1871            from_completion_edits(
1872                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1873                &buffer,
1874                cx
1875            ),
1876            vec![(2..2, "REM".into()), (6..8, "".into())]
1877        );
1878
1879        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1880        assert_eq!(
1881            from_completion_edits(
1882                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1883                &buffer,
1884                cx
1885            ),
1886            vec![(2..5, "REM".into()), (9..11, "".into())]
1887        );
1888
1889        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1890        assert_eq!(
1891            from_completion_edits(
1892                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1893                &buffer,
1894                cx
1895            ),
1896            vec![(3..3, "EM".into()), (7..9, "".into())]
1897        );
1898
1899        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1900        assert_eq!(
1901            from_completion_edits(
1902                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1903                &buffer,
1904                cx
1905            ),
1906            vec![(4..4, "M".into()), (8..10, "".into())]
1907        );
1908
1909        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1910        assert_eq!(
1911            from_completion_edits(
1912                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1913                &buffer,
1914                cx
1915            ),
1916            vec![(9..11, "".into())]
1917        );
1918
1919        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1920        assert_eq!(
1921            from_completion_edits(
1922                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1923                &buffer,
1924                cx
1925            ),
1926            vec![(4..4, "M".into()), (8..10, "".into())]
1927        );
1928
1929        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1930        assert_eq!(
1931            from_completion_edits(
1932                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1933                &buffer,
1934                cx
1935            ),
1936            vec![(4..4, "M".into())]
1937        );
1938
1939        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1940        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1941    })
1942}
1943
1944#[gpui::test]
1945async fn test_clean_up_diff(cx: &mut TestAppContext) {
1946    init_test(cx);
1947
1948    assert_eq!(
1949        apply_edit_prediction(
1950            indoc! {"
1951                    fn main() {
1952                        let word_1 = \"lorem\";
1953                        let range = word.len()..word.len();
1954                    }
1955                "},
1956            indoc! {"
1957                    fn main() {
1958                        let word_1 = \"lorem\";
1959                        let range = word_1.len()..word_1.len();
1960                    }
1961                "},
1962            cx,
1963        )
1964        .await,
1965        indoc! {"
1966                fn main() {
1967                    let word_1 = \"lorem\";
1968                    let range = word_1.len()..word_1.len();
1969                }
1970            "},
1971    );
1972
1973    assert_eq!(
1974        apply_edit_prediction(
1975            indoc! {"
1976                    fn main() {
1977                        let story = \"the quick\"
1978                    }
1979                "},
1980            indoc! {"
1981                    fn main() {
1982                        let story = \"the quick brown fox jumps over the lazy dog\";
1983                    }
1984                "},
1985            cx,
1986        )
1987        .await,
1988        indoc! {"
1989                fn main() {
1990                    let story = \"the quick brown fox jumps over the lazy dog\";
1991                }
1992            "},
1993    );
1994}
1995
1996#[gpui::test]
1997async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1998    init_test(cx);
1999
2000    let buffer_content = "lorem\n";
2001    let completion_response = "lorem\nipsum\n";
2002
2003    assert_eq!(
2004        apply_edit_prediction(buffer_content, completion_response, cx).await,
2005        "lorem\nipsum\n"
2006    );
2007}
2008
2009#[gpui::test]
2010async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
2011    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
2012    // When the buffer ends without a trailing newline, but the model returns output
2013    // with a trailing newline, zeta2 should normalize both sides before diffing
2014    // so no spurious newline is inserted.
2015    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2016    let fs = FakeFs::new(cx.executor());
2017
2018    // Single line buffer with no trailing newline
2019    fs.insert_tree(
2020        "/root",
2021        json!({
2022            "foo.txt": "hello"
2023        }),
2024    )
2025    .await;
2026    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2027
2028    let buffer = project
2029        .update(cx, |project, cx| {
2030            let path = project
2031                .find_project_path(path!("root/foo.txt"), cx)
2032                .unwrap();
2033            project.open_buffer(path, cx)
2034        })
2035        .await
2036        .unwrap();
2037
2038    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2039    let position = snapshot.anchor_before(language::Point::new(0, 5));
2040
2041    ep_store.update(cx, |ep_store, cx| {
2042        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2043    });
2044
2045    let (request, respond_tx) = requests.predict.next().await.unwrap();
2046
2047    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2048    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2049    let excerpt_length = request.input.cursor_excerpt.len();
2050    let response = PredictEditsV3Response {
2051        request_id: Uuid::new_v4().to_string(),
2052        output: "hello world\n".to_string(),
2053        editable_range: 0..excerpt_length,
2054        model_version: None,
2055    };
2056    respond_tx.send(response).unwrap();
2057
2058    cx.run_until_parked();
2059
2060    // The prediction should insert " world" without adding a newline
2061    ep_store.update(cx, |ep_store, cx| {
2062        let prediction = ep_store
2063            .prediction_at(&buffer, None, &project, cx)
2064            .expect("should have prediction");
2065        let edits: Vec<_> = prediction
2066            .edits
2067            .iter()
2068            .map(|(range, text)| {
2069                let snapshot = buffer.read(cx).snapshot();
2070                (range.to_offset(&snapshot), text.clone())
2071            })
2072            .collect();
2073        assert_eq!(edits, vec![(5..5, " world".into())]);
2074    });
2075}
2076
2077fn init_test(cx: &mut TestAppContext) {
2078    cx.update(|cx| {
2079        let settings_store = SettingsStore::test(cx);
2080        cx.set_global(settings_store);
2081    });
2082}
2083
2084async fn apply_edit_prediction(
2085    buffer_content: &str,
2086    completion_response: &str,
2087    cx: &mut TestAppContext,
2088) -> String {
2089    let fs = project::FakeFs::new(cx.executor());
2090    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2091    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2092    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2093    *response.lock() = completion_response.to_string();
2094    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2095    buffer.update(cx, |buffer, cx| {
2096        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2097    });
2098    buffer.read_with(cx, |buffer, _| buffer.text())
2099}
2100
2101async fn run_edit_prediction(
2102    buffer: &Entity<Buffer>,
2103    project: &Entity<Project>,
2104    ep_store: &Entity<EditPredictionStore>,
2105    cx: &mut TestAppContext,
2106) -> EditPrediction {
2107    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2108    ep_store.update(cx, |ep_store, cx| {
2109        ep_store.register_buffer(buffer, &project, cx)
2110    });
2111    cx.background_executor.run_until_parked();
2112    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2113        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2114    });
2115    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2116}
2117
2118async fn make_test_ep_store(
2119    project: &Entity<Project>,
2120    cx: &mut TestAppContext,
2121) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2122    let default_response = "hello world\n".to_string();
2123    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2124    let http_client = FakeHttpClient::create({
2125        let completion_response = completion_response.clone();
2126        let mut next_request_id = 0;
2127        move |req| {
2128            let completion_response = completion_response.clone();
2129            let method = req.method().clone();
2130            let uri = req.uri().path().to_string();
2131            let mut body = req.into_body();
2132            async move {
2133                match (method, uri.as_str()) {
2134                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2135                        .status(200)
2136                        .body(
2137                            serde_json::to_string(&CreateLlmTokenResponse {
2138                                token: LlmToken("the-llm-token".to_string()),
2139                            })
2140                            .unwrap()
2141                            .into(),
2142                        )
2143                        .unwrap()),
2144                    (Method::POST, "/predict_edits/v3") => {
2145                        let mut buf = Vec::new();
2146                        body.read_to_end(&mut buf).await.ok();
2147                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
2148                        let req: PredictEditsV3Request =
2149                            serde_json::from_slice(&decompressed).unwrap();
2150
2151                        next_request_id += 1;
2152                        Ok(http_client::Response::builder()
2153                            .status(200)
2154                            .body(
2155                                serde_json::to_string(&PredictEditsV3Response {
2156                                    request_id: format!("request-{next_request_id}"),
2157                                    editable_range: 0..req.input.cursor_excerpt.len(),
2158                                    output: completion_response.lock().clone(),
2159                                    model_version: None,
2160                                })
2161                                .unwrap()
2162                                .into(),
2163                            )
2164                            .unwrap())
2165                    }
2166                    _ => Ok(http_client::Response::builder()
2167                        .status(404)
2168                        .body("Not Found".to_string().into())
2169                        .unwrap()),
2170                }
2171            }
2172        }
2173    });
2174
2175    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2176    cx.update(|cx| {
2177        RefreshLlmTokenListener::register(client.clone(), cx);
2178    });
2179    let _server = FakeServer::for_client(42, &client, cx).await;
2180
2181    let ep_store = cx.new(|cx| {
2182        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2183        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2184
2185        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2186        for worktree in worktrees {
2187            let worktree_id = worktree.read(cx).id();
2188            ep_store
2189                .get_or_init_project(project, cx)
2190                .license_detection_watchers
2191                .entry(worktree_id)
2192                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2193        }
2194
2195        ep_store
2196    });
2197
2198    (ep_store, completion_response)
2199}
2200
2201fn to_completion_edits(
2202    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2203    buffer: &Entity<Buffer>,
2204    cx: &App,
2205) -> Vec<(Range<Anchor>, Arc<str>)> {
2206    let buffer = buffer.read(cx);
2207    iterator
2208        .into_iter()
2209        .map(|(range, text)| {
2210            (
2211                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2212                text,
2213            )
2214        })
2215        .collect()
2216}
2217
2218fn from_completion_edits(
2219    editor_edits: &[(Range<Anchor>, Arc<str>)],
2220    buffer: &Entity<Buffer>,
2221    cx: &App,
2222) -> Vec<(Range<usize>, Arc<str>)> {
2223    let buffer = buffer.read(cx);
2224    editor_edits
2225        .iter()
2226        .map(|(range, text)| {
2227            (
2228                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2229                text.clone(),
2230            )
2231        })
2232        .collect()
2233}
2234
2235#[gpui::test]
2236async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2237    init_test(cx);
2238
2239    let fs = FakeFs::new(cx.executor());
2240    fs.insert_tree(
2241        "/project",
2242        serde_json::json!({
2243            "main.rs": "fn main() {\n    \n}\n"
2244        }),
2245    )
2246    .await;
2247
2248    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2249
2250    let http_client = FakeHttpClient::create(|_req| async move {
2251        Ok(gpui::http_client::Response::builder()
2252            .status(401)
2253            .body("Unauthorized".into())
2254            .unwrap())
2255    });
2256
2257    let client =
2258        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2259    cx.update(|cx| {
2260        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2261    });
2262
2263    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2264
2265    let buffer = project
2266        .update(cx, |project, cx| {
2267            let path = project
2268                .find_project_path(path!("/project/main.rs"), cx)
2269                .unwrap();
2270            project.open_buffer(path, cx)
2271        })
2272        .await
2273        .unwrap();
2274
2275    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2276    ep_store.update(cx, |ep_store, cx| {
2277        ep_store.register_buffer(&buffer, &project, cx)
2278    });
2279    cx.background_executor.run_until_parked();
2280
2281    let completion_task = ep_store.update(cx, |ep_store, cx| {
2282        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2283        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2284    });
2285
2286    let result = completion_task.await;
2287    assert!(
2288        result.is_err(),
2289        "Without authentication and without custom URL, prediction should fail"
2290    );
2291}
2292
2293#[gpui::test]
2294fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2295    let buffer = cx.new(|cx| {
2296        Buffer::local(
2297            indoc! {"
2298                zero
2299                one
2300                two
2301                three
2302                four
2303                five
2304                six
2305                seven
2306                eight
2307                nine
2308                ten
2309                eleven
2310                twelve
2311                thirteen
2312                fourteen
2313                fifteen
2314                sixteen
2315                seventeen
2316                eighteen
2317                nineteen
2318                twenty
2319                twenty-one
2320                twenty-two
2321                twenty-three
2322                twenty-four
2323            "},
2324            cx,
2325        )
2326    });
2327
2328    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2329
2330    buffer.update(cx, |buffer, cx| {
2331        let point = Point::new(12, 0);
2332        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2333        let point = Point::new(8, 0);
2334        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2335    });
2336
2337    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2338
2339    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2340
2341    assert_eq!(
2342        diff,
2343        indoc! {"
2344            @@ -6,10 +6,12 @@
2345             five
2346             six
2347             seven
2348            +FIRST INSERTION
2349             eight
2350             nine
2351             ten
2352             eleven
2353            +SECOND INSERTION
2354             twelve
2355             thirteen
2356             fourteen
2357            "}
2358    );
2359}
2360
2361#[gpui::test]
2362async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
2363    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
2364        let collab_replica = clock::ReplicaId::new(10);
2365        let anchor = buffer.read_with(cx, |buffer, _| {
2366            buffer.snapshot().anchor_before(Point::new(row, 0))
2367        });
2368        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
2369            id: 1,
2370            start: anchor,
2371            end: anchor,
2372            reversed: false,
2373            goal: SelectionGoal::None,
2374        }]);
2375        buffer.update(cx, |buffer, cx| {
2376            buffer.apply_ops(
2377                [Operation::UpdateSelections {
2378                    selections,
2379                    lamport_timestamp: clock::Lamport {
2380                        replica_id: collab_replica,
2381                        value: 1,
2382                    },
2383                    line_mode: false,
2384                    cursor_shape: CursorShape::Bar,
2385                }],
2386                cx,
2387            );
2388        });
2389    }
2390
2391    fn publish_diagnostics(
2392        uri_path: &'static str,
2393        rows: &[u32],
2394        project: &Entity<Project>,
2395        cx: &mut TestAppContext,
2396    ) {
2397        let diagnostics: Vec<_> = rows
2398            .iter()
2399            .map(|&row| lsp::Diagnostic {
2400                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
2401                severity: Some(lsp::DiagnosticSeverity::ERROR),
2402                message: format!("error at row {row}"),
2403                ..Default::default()
2404            })
2405            .collect();
2406        project.update(cx, |project, cx| {
2407            project.lsp_store().update(cx, |lsp_store, cx| {
2408                lsp_store
2409                    .update_diagnostics(
2410                        LanguageServerId(0),
2411                        lsp::PublishDiagnosticsParams {
2412                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
2413                            diagnostics,
2414                            version: None,
2415                        },
2416                        None,
2417                        language::DiagnosticSourceKind::Pushed,
2418                        &[],
2419                        cx,
2420                    )
2421                    .expect("failed to update diagnostics");
2422            });
2423        });
2424    }
2425
2426    init_test(cx);
2427
2428    let mut lines = String::new();
2429    for i in 0..60 {
2430        lines.push_str(&format!("line {i}\n"));
2431    }
2432
2433    let fs = FakeFs::new(cx.executor());
2434    fs.insert_tree(
2435        "/root",
2436        json!({
2437            "active.txt": lines,
2438            "collab_file.txt": "error here\nsecond line\n",
2439            "free_file.txt": "another error\nsecond line\n",
2440        }),
2441    )
2442    .await;
2443    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2444
2445    let active_buffer = project
2446        .update(cx, |project, cx| {
2447            let path = project
2448                .find_project_path(path!("/root/active.txt"), cx)
2449                .expect("active.txt not found");
2450            project.set_active_path(Some(path.clone()), cx);
2451            project.open_buffer(path, cx)
2452        })
2453        .await
2454        .expect("failed to open active buffer");
2455
2456    set_collaborator_cursor(&active_buffer, 5, cx);
2457
2458    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);
2459
2460    cx.run_until_parked();
2461
2462    let cursor_point = Point::new(25, 0);
2463    let empty_search_range: Range<Point> = Default::default();
2464
2465    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2466    let result = EditPredictionStore::next_diagnostic_location(
2467        active_buffer.clone(),
2468        &snapshot,
2469        empty_search_range.clone(),
2470        cursor_point,
2471        &project,
2472        &mut cx.to_async(),
2473    )
2474    .await
2475    .expect("next_diagnostic_location failed");
2476
2477    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
2478    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
2479    let result_row = result_buffer.read_with(cx, |buffer, _| {
2480        result_anchor.to_point(&buffer.snapshot()).row
2481    });
2482    assert_ne!(
2483        result_row, 3,
2484        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
2485    );
2486    assert!(
2487        result_row == 25 || result_row == 50,
2488        "expected row 25 or 50, got {result_row}"
2489    );
2490
2491    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2492    let near_cursor_point = Point::new(4, 0);
2493    let result_near = EditPredictionStore::next_diagnostic_location(
2494        active_buffer.clone(),
2495        &snapshot_near,
2496        empty_search_range.clone(),
2497        near_cursor_point,
2498        &project,
2499        &mut cx.to_async(),
2500    )
2501    .await
2502    .expect("next_diagnostic_location failed");
2503
2504    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
2505    let near_row =
2506        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
2507    assert_eq!(
2508        near_row, 3,
2509        "row 3 should be included when local cursor (row 4) is also near the collaborator"
2510    );
2511
2512    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2513    let far_cursor_point = Point::new(50, 0);
2514    let result_far = EditPredictionStore::next_diagnostic_location(
2515        active_buffer.clone(),
2516        &snapshot_far,
2517        empty_search_range.clone(),
2518        far_cursor_point,
2519        &project,
2520        &mut cx.to_async(),
2521    )
2522    .await
2523    .expect("next_diagnostic_location failed");
2524
2525    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
2526    let far_row =
2527        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
2528    assert_eq!(
2529        far_row, 50,
2530        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
2531    );
2532
2533    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
2534    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
2535    cx.run_until_parked();
2536
2537    let collab_buffer = project
2538        .update(cx, |project, cx| {
2539            let path = project
2540                .find_project_path(path!("/root/collab_file.txt"), cx)
2541                .expect("collab_file.txt not found");
2542            project.open_buffer(path, cx)
2543        })
2544        .await
2545        .expect("failed to open collab buffer");
2546
2547    set_collaborator_cursor(&collab_buffer, 0, cx);
2548    cx.run_until_parked();
2549
2550    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
2551    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2552    let result_cross = EditPredictionStore::next_diagnostic_location(
2553        active_buffer.clone(),
2554        &snapshot_cross,
2555        no_same_file_search_range,
2556        Point::new(0, 0),
2557        &project,
2558        &mut cx.to_async(),
2559    )
2560    .await
2561    .expect("cross-file next_diagnostic_location failed");
2562
2563    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
2564    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
2565        buffer
2566            .file()
2567            .expect("buffer should have a file")
2568            .full_path(cx)
2569    });
2570    assert_eq!(
2571        cross_path,
2572        Path::new(path!("root/free_file.txt")),
2573        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
2574    );
2575}
2576
/// Initializes test logging via `zlog::init_test`. The `#[ctor::ctor]`
/// attribute runs this at program load, before the test harness starts
/// running tests.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}