edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
   3use client::{UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
   8    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
   9};
  10use futures::{
  11    AsyncReadExt, FutureExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::App;
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::path;
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
  34#[gpui::test]
  35async fn test_current_state(cx: &mut TestAppContext) {
  36    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  37    let fs = FakeFs::new(cx.executor());
  38    fs.insert_tree(
  39        "/root",
  40        json!({
  41            "1.txt": "Hello!\nHow\nBye\n",
  42            "2.txt": "Hola!\nComo\nAdios\n"
  43        }),
  44    )
  45    .await;
  46    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  47
  48    let buffer1 = project
  49        .update(cx, |project, cx| {
  50            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  51            project.set_active_path(Some(path.clone()), cx);
  52            project.open_buffer(path, cx)
  53        })
  54        .await
  55        .unwrap();
  56    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  57    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  58
  59    ep_store.update(cx, |ep_store, cx| {
  60        ep_store.register_project(&project, cx);
  61        ep_store.register_buffer(&buffer1, &project, cx);
  62    });
  63
  64    // Prediction for current file
  65
  66    ep_store.update(cx, |ep_store, cx| {
  67        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  68    });
  69    let (request, respond_tx) = requests.predict.next().await.unwrap();
  70
  71    respond_tx
  72        .send(model_response(
  73            &request,
  74            indoc! {r"
  75                --- a/root/1.txt
  76                +++ b/root/1.txt
  77                @@ ... @@
  78                 Hello!
  79                -How
  80                +How are you?
  81                 Bye
  82            "},
  83        ))
  84        .unwrap();
  85
  86    cx.run_until_parked();
  87
  88    ep_store.update(cx, |ep_store, cx| {
  89        let prediction = ep_store
  90            .prediction_at(&buffer1, None, &project, cx)
  91            .unwrap();
  92        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  93    });
  94
  95    ep_store.update(cx, |ep_store, cx| {
  96        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
  97    });
  98
  99    // Prediction for diagnostic in another file
 100
 101    let diagnostic = lsp::Diagnostic {
 102        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 103        severity: Some(lsp::DiagnosticSeverity::ERROR),
 104        message: "Sentence is incomplete".to_string(),
 105        ..Default::default()
 106    };
 107
 108    project.update(cx, |project, cx| {
 109        project.lsp_store().update(cx, |lsp_store, cx| {
 110            lsp_store
 111                .update_diagnostics(
 112                    LanguageServerId(0),
 113                    lsp::PublishDiagnosticsParams {
 114                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 115                        diagnostics: vec![diagnostic],
 116                        version: None,
 117                    },
 118                    None,
 119                    language::DiagnosticSourceKind::Pushed,
 120                    &[],
 121                    cx,
 122                )
 123                .unwrap();
 124        });
 125    });
 126
 127    let (request, respond_tx) = requests.predict.next().await.unwrap();
 128    respond_tx
 129        .send(model_response(
 130            &request,
 131            indoc! {r#"
 132                --- a/root/2.txt
 133                +++ b/root/2.txt
 134                @@ ... @@
 135                 Hola!
 136                -Como
 137                +Como estas?
 138                 Adios
 139            "#},
 140        ))
 141        .unwrap();
 142    cx.run_until_parked();
 143
 144    ep_store.update(cx, |ep_store, cx| {
 145        let prediction = ep_store
 146            .prediction_at(&buffer1, None, &project, cx)
 147            .unwrap();
 148        assert_matches!(
 149            prediction,
 150            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 151        );
 152    });
 153
 154    let buffer2 = project
 155        .update(cx, |project, cx| {
 156            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 157            project.open_buffer(path, cx)
 158        })
 159        .await
 160        .unwrap();
 161
 162    ep_store.update(cx, |ep_store, cx| {
 163        let prediction = ep_store
 164            .prediction_at(&buffer2, None, &project, cx)
 165            .unwrap();
 166        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 167    });
 168}
 169
 170#[gpui::test]
 171async fn test_simple_request(cx: &mut TestAppContext) {
 172    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 173    let fs = FakeFs::new(cx.executor());
 174    fs.insert_tree(
 175        "/root",
 176        json!({
 177            "foo.md":  "Hello!\nHow\nBye\n"
 178        }),
 179    )
 180    .await;
 181    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 182
 183    let buffer = project
 184        .update(cx, |project, cx| {
 185            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 186            project.open_buffer(path, cx)
 187        })
 188        .await
 189        .unwrap();
 190    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 191    let position = snapshot.anchor_before(language::Point::new(1, 3));
 192
 193    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 194        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 195    });
 196
 197    let (request, respond_tx) = requests.predict.next().await.unwrap();
 198
 199    // TODO Put back when we have a structured request again
 200    // assert_eq!(
 201    //     request.excerpt_path.as_ref(),
 202    //     Path::new(path!("root/foo.md"))
 203    // );
 204    // assert_eq!(
 205    //     request.cursor_point,
 206    //     Point {
 207    //         line: Line(1),
 208    //         column: 3
 209    //     }
 210    // );
 211
 212    respond_tx
 213        .send(model_response(
 214            &request,
 215            indoc! { r"
 216                --- a/root/foo.md
 217                +++ b/root/foo.md
 218                @@ ... @@
 219                 Hello!
 220                -How
 221                +How are you?
 222                 Bye
 223            "},
 224        ))
 225        .unwrap();
 226
 227    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 228
 229    assert_eq!(prediction.edits.len(), 1);
 230    assert_eq!(
 231        prediction.edits[0].0.to_point(&snapshot).start,
 232        language::Point::new(1, 3)
 233    );
 234    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 235}
 236
 237#[gpui::test]
 238async fn test_request_events(cx: &mut TestAppContext) {
 239    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 240    let fs = FakeFs::new(cx.executor());
 241    fs.insert_tree(
 242        "/root",
 243        json!({
 244            "foo.md": "Hello!\n\nBye\n"
 245        }),
 246    )
 247    .await;
 248    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 249
 250    let buffer = project
 251        .update(cx, |project, cx| {
 252            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 253            project.open_buffer(path, cx)
 254        })
 255        .await
 256        .unwrap();
 257
 258    ep_store.update(cx, |ep_store, cx| {
 259        ep_store.register_buffer(&buffer, &project, cx);
 260    });
 261
 262    buffer.update(cx, |buffer, cx| {
 263        buffer.edit(vec![(7..7, "How")], None, cx);
 264    });
 265
 266    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 267    let position = snapshot.anchor_before(language::Point::new(1, 3));
 268
 269    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 270        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 271    });
 272
 273    let (request, respond_tx) = requests.predict.next().await.unwrap();
 274
 275    let prompt = prompt_from_request(&request);
 276    assert!(
 277        prompt.contains(indoc! {"
 278        --- a/root/foo.md
 279        +++ b/root/foo.md
 280        @@ -1,3 +1,3 @@
 281         Hello!
 282        -
 283        +How
 284         Bye
 285    "}),
 286        "{prompt}"
 287    );
 288
 289    respond_tx
 290        .send(model_response(
 291            &request,
 292            indoc! {r#"
 293                --- a/root/foo.md
 294                +++ b/root/foo.md
 295                @@ ... @@
 296                 Hello!
 297                -How
 298                +How are you?
 299                 Bye
 300        "#},
 301        ))
 302        .unwrap();
 303
 304    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 305
 306    assert_eq!(prediction.edits.len(), 1);
 307    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 308}
 309
 310#[gpui::test]
 311async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 312    let (ep_store, _requests) = init_test_with_fake_client(cx);
 313    let fs = FakeFs::new(cx.executor());
 314    fs.insert_tree(
 315        "/root",
 316        json!({
 317            "foo.md": "Hello!\n\nBye\n"
 318        }),
 319    )
 320    .await;
 321    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 322
 323    let buffer = project
 324        .update(cx, |project, cx| {
 325            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 326            project.open_buffer(path, cx)
 327        })
 328        .await
 329        .unwrap();
 330
 331    ep_store.update(cx, |ep_store, cx| {
 332        ep_store.register_buffer(&buffer, &project, cx);
 333    });
 334
 335    // First burst: insert "How"
 336    buffer.update(cx, |buffer, cx| {
 337        buffer.edit(vec![(7..7, "How")], None, cx);
 338    });
 339
 340    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 341    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 342    cx.run_until_parked();
 343
 344    // Second burst: append " are you?" immediately after "How" on the same line.
 345    //
 346    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 347    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 348    buffer.update(cx, |buffer, cx| {
 349        buffer.edit(vec![(10..10, " are you?")], None, cx);
 350    });
 351
 352    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 353    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 354    buffer.update(cx, |buffer, cx| {
 355        buffer.edit(vec![(19..19, "!")], None, cx);
 356    });
 357
 358    // With time-based splitting, there are two distinct events.
 359    let events = ep_store.update(cx, |ep_store, cx| {
 360        ep_store.edit_history_for_project(&project, cx)
 361    });
 362    assert_eq!(events.len(), 2);
 363    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 364    assert_eq!(
 365        diff.as_str(),
 366        indoc! {"
 367            @@ -1,3 +1,3 @@
 368             Hello!
 369            -
 370            +How
 371             Bye
 372        "}
 373    );
 374
 375    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 376    assert_eq!(
 377        diff.as_str(),
 378        indoc! {"
 379            @@ -1,3 +1,3 @@
 380             Hello!
 381            -How
 382            +How are you?!
 383             Bye
 384        "}
 385    );
 386}
 387
 388#[gpui::test]
 389async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
 390    let (ep_store, _requests) = init_test_with_fake_client(cx);
 391    let fs = FakeFs::new(cx.executor());
 392
 393    // Create a file with 30 lines to test line-based coalescing
 394    let content = (1..=30)
 395        .map(|i| format!("Line {}\n", i))
 396        .collect::<String>();
 397    fs.insert_tree(
 398        "/root",
 399        json!({
 400            "foo.md": content
 401        }),
 402    )
 403    .await;
 404    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 405
 406    let buffer = project
 407        .update(cx, |project, cx| {
 408            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 409            project.open_buffer(path, cx)
 410        })
 411        .await
 412        .unwrap();
 413
 414    ep_store.update(cx, |ep_store, cx| {
 415        ep_store.register_buffer(&buffer, &project, cx);
 416    });
 417
 418    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
 419    buffer.update(cx, |buffer, cx| {
 420        let start = Point::new(10, 0).to_offset(buffer);
 421        let end = Point::new(13, 0).to_offset(buffer);
 422        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
 423    });
 424
 425    let events = ep_store.update(cx, |ep_store, cx| {
 426        ep_store.edit_history_for_project(&project, cx)
 427    });
 428    assert_eq!(
 429        render_events(&events),
 430        indoc! {"
 431            @@ -8,9 +8,8 @@
 432             Line 8
 433             Line 9
 434             Line 10
 435            -Line 11
 436            -Line 12
 437            -Line 13
 438            +Middle A
 439            +Middle B
 440             Line 14
 441             Line 15
 442             Line 16
 443        "},
 444        "After first edit"
 445    );
 446
 447    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
 448    // This tests that coalescing considers the START of the existing range
 449    buffer.update(cx, |buffer, cx| {
 450        let offset = Point::new(5, 0).to_offset(buffer);
 451        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
 452    });
 453
 454    let events = ep_store.update(cx, |ep_store, cx| {
 455        ep_store.edit_history_for_project(&project, cx)
 456    });
 457    assert_eq!(
 458        render_events(&events),
 459        indoc! {"
 460            @@ -3,14 +3,14 @@
 461             Line 3
 462             Line 4
 463             Line 5
 464            +Above
 465             Line 6
 466             Line 7
 467             Line 8
 468             Line 9
 469             Line 10
 470            -Line 11
 471            -Line 12
 472            -Line 13
 473            +Middle A
 474            +Middle B
 475             Line 14
 476             Line 15
 477             Line 16
 478        "},
 479        "After inserting above (should coalesce)"
 480    );
 481
 482    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
 483    // This tests that coalescing considers the END of the existing range
 484    buffer.update(cx, |buffer, cx| {
 485        let offset = Point::new(14, 0).to_offset(buffer);
 486        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
 487    });
 488
 489    let events = ep_store.update(cx, |ep_store, cx| {
 490        ep_store.edit_history_for_project(&project, cx)
 491    });
 492    assert_eq!(
 493        render_events(&events),
 494        indoc! {"
 495            @@ -3,15 +3,16 @@
 496             Line 3
 497             Line 4
 498             Line 5
 499            +Above
 500             Line 6
 501             Line 7
 502             Line 8
 503             Line 9
 504             Line 10
 505            -Line 11
 506            -Line 12
 507            -Line 13
 508            +Middle A
 509            +Middle B
 510             Line 14
 511            +Below
 512             Line 15
 513             Line 16
 514             Line 17
 515        "},
 516        "After inserting below (should coalesce)"
 517    );
 518
 519    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
 520    // This should NOT coalesce - creates a new event
 521    buffer.update(cx, |buffer, cx| {
 522        let offset = Point::new(25, 0).to_offset(buffer);
 523        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
 524    });
 525
 526    let events = ep_store.update(cx, |ep_store, cx| {
 527        ep_store.edit_history_for_project(&project, cx)
 528    });
 529    assert_eq!(
 530        render_events(&events),
 531        indoc! {"
 532            @@ -3,15 +3,16 @@
 533             Line 3
 534             Line 4
 535             Line 5
 536            +Above
 537             Line 6
 538             Line 7
 539             Line 8
 540             Line 9
 541             Line 10
 542            -Line 11
 543            -Line 12
 544            -Line 13
 545            +Middle A
 546            +Middle B
 547             Line 14
 548            +Below
 549             Line 15
 550             Line 16
 551             Line 17
 552
 553            ---
 554            @@ -23,6 +23,7 @@
 555             Line 22
 556             Line 23
 557             Line 24
 558            +Far below
 559             Line 25
 560             Line 26
 561             Line 27
 562        "},
 563        "After inserting far below (should NOT coalesce)"
 564    );
 565}
 566
 567fn render_events(events: &[StoredEvent]) -> String {
 568    events
 569        .iter()
 570        .map(|e| {
 571            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 572            diff.as_str()
 573        })
 574        .collect::<Vec<_>>()
 575        .join("\n---\n")
 576}
 577
 578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 579    events
 580        .iter()
 581        .map(|e| {
 582            let zeta_prompt::Event::BufferChange {
 583                diff, predicted, ..
 584            } = e.event.as_ref();
 585            let prefix = if *predicted { "predicted" } else { "manual" };
 586            format!("{}\n{}", prefix, diff)
 587        })
 588        .collect()
 589}
 590
 591#[gpui::test]
 592async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
 593    let (ep_store, _requests) = init_test_with_fake_client(cx);
 594    let fs = FakeFs::new(cx.executor());
 595    fs.insert_tree(
 596        "/root",
 597        json!({
 598            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
 599        }),
 600    )
 601    .await;
 602    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 603
 604    let buffer = project
 605        .update(cx, |project, cx| {
 606            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
 607            project.open_buffer(path, cx)
 608        })
 609        .await
 610        .unwrap();
 611
 612    ep_store.update(cx, |ep_store, cx| {
 613        ep_store.register_buffer(&buffer, &project, cx);
 614    });
 615
 616    // Case 1: Manual edits have `predicted` set to false.
 617    buffer.update(cx, |buffer, cx| {
 618        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
 619    });
 620
 621    let events = ep_store.update(cx, |ep_store, cx| {
 622        ep_store.edit_history_for_project(&project, cx)
 623    });
 624
 625    assert_eq!(
 626        render_events_with_predicted(&events),
 627        vec![indoc! {"
 628            manual
 629            @@ -1,4 +1,4 @@
 630            -line 0
 631            +LINE ZERO
 632             line 1
 633             line 2
 634             line 3
 635        "}]
 636    );
 637
 638    // Case 2: Multiple successive manual edits near each other are merged into one
 639    // event with `predicted` set to false.
 640    buffer.update(cx, |buffer, cx| {
 641        let offset = Point::new(1, 0).to_offset(buffer);
 642        let end = Point::new(1, 6).to_offset(buffer);
 643        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
 644    });
 645
 646    let events = ep_store.update(cx, |ep_store, cx| {
 647        ep_store.edit_history_for_project(&project, cx)
 648    });
 649    assert_eq!(
 650        render_events_with_predicted(&events),
 651        vec![indoc! {"
 652            manual
 653            @@ -1,5 +1,5 @@
 654            -line 0
 655            -line 1
 656            +LINE ZERO
 657            +LINE ONE
 658             line 2
 659             line 3
 660             line 4
 661        "}]
 662    );
 663
 664    // Case 3: Accepted predictions have `predicted` set to true.
 665    // Case 5: A manual edit that follows a predicted edit is not merged with the
 666    // predicted edit, even if it is nearby.
 667    ep_store.update(cx, |ep_store, cx| {
 668        buffer.update(cx, |buffer, cx| {
 669            let offset = Point::new(2, 0).to_offset(buffer);
 670            let end = Point::new(2, 6).to_offset(buffer);
 671            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
 672        });
 673        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
 674    });
 675
 676    let events = ep_store.update(cx, |ep_store, cx| {
 677        ep_store.edit_history_for_project(&project, cx)
 678    });
 679    assert_eq!(
 680        render_events_with_predicted(&events),
 681        vec![
 682            indoc! {"
 683                manual
 684                @@ -1,5 +1,5 @@
 685                -line 0
 686                -line 1
 687                +LINE ZERO
 688                +LINE ONE
 689                 line 2
 690                 line 3
 691                 line 4
 692            "},
 693            indoc! {"
 694                predicted
 695                @@ -1,6 +1,6 @@
 696                 LINE ZERO
 697                 LINE ONE
 698                -line 2
 699                +LINE TWO
 700                 line 3
 701                 line 4
 702                 line 5
 703            "}
 704        ]
 705    );
 706
 707    // Case 4: Multiple successive accepted predictions near each other are merged
 708    // into one event with `predicted` set to true.
 709    ep_store.update(cx, |ep_store, cx| {
 710        buffer.update(cx, |buffer, cx| {
 711            let offset = Point::new(3, 0).to_offset(buffer);
 712            let end = Point::new(3, 6).to_offset(buffer);
 713            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
 714        });
 715        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
 716    });
 717
 718    let events = ep_store.update(cx, |ep_store, cx| {
 719        ep_store.edit_history_for_project(&project, cx)
 720    });
 721    assert_eq!(
 722        render_events_with_predicted(&events),
 723        vec![
 724            indoc! {"
 725                manual
 726                @@ -1,5 +1,5 @@
 727                -line 0
 728                -line 1
 729                +LINE ZERO
 730                +LINE ONE
 731                 line 2
 732                 line 3
 733                 line 4
 734            "},
 735            indoc! {"
 736                predicted
 737                @@ -1,7 +1,7 @@
 738                 LINE ZERO
 739                 LINE ONE
 740                -line 2
 741                -line 3
 742                +LINE TWO
 743                +LINE THREE
 744                 line 4
 745                 line 5
 746                 line 6
 747            "}
 748        ]
 749    );
 750
 751    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
 752    // with the predicted edit, even if it is nearby.
 753    buffer.update(cx, |buffer, cx| {
 754        let offset = Point::new(4, 0).to_offset(buffer);
 755        let end = Point::new(4, 6).to_offset(buffer);
 756        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
 757    });
 758
 759    let events = ep_store.update(cx, |ep_store, cx| {
 760        ep_store.edit_history_for_project(&project, cx)
 761    });
 762    assert_eq!(
 763        render_events_with_predicted(&events),
 764        vec![
 765            indoc! {"
 766                manual
 767                @@ -1,5 +1,5 @@
 768                -line 0
 769                -line 1
 770                +LINE ZERO
 771                +LINE ONE
 772                 line 2
 773                 line 3
 774                 line 4
 775            "},
 776            indoc! {"
 777                predicted
 778                @@ -1,7 +1,7 @@
 779                 LINE ZERO
 780                 LINE ONE
 781                -line 2
 782                -line 3
 783                +LINE TWO
 784                +LINE THREE
 785                 line 4
 786                 line 5
 787                 line 6
 788            "},
 789            indoc! {"
 790                manual
 791                @@ -2,7 +2,7 @@
 792                 LINE ONE
 793                 LINE TWO
 794                 LINE THREE
 795                -line 4
 796                +LINE FOUR
 797                 line 5
 798                 line 6
 799                 line 7
 800            "}
 801        ]
 802    );
 803
 804    // Case 6: If we then perform a manual edit at a *different* location (more than
 805    // 8 lines away), then the edits at the prior location can be merged with each
 806    // other, even if some are predicted and some are not. `predicted` means all
 807    // constituent edits were predicted.
 808    buffer.update(cx, |buffer, cx| {
 809        let offset = Point::new(14, 0).to_offset(buffer);
 810        let end = Point::new(14, 7).to_offset(buffer);
 811        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
 812    });
 813
 814    let events = ep_store.update(cx, |ep_store, cx| {
 815        ep_store.edit_history_for_project(&project, cx)
 816    });
 817    assert_eq!(
 818        render_events_with_predicted(&events),
 819        vec![
 820            indoc! {"
 821                manual
 822                @@ -1,8 +1,8 @@
 823                -line 0
 824                -line 1
 825                -line 2
 826                -line 3
 827                -line 4
 828                +LINE ZERO
 829                +LINE ONE
 830                +LINE TWO
 831                +LINE THREE
 832                +LINE FOUR
 833                 line 5
 834                 line 6
 835                 line 7
 836            "},
 837            indoc! {"
 838                manual
 839                @@ -12,4 +12,4 @@
 840                 line 11
 841                 line 12
 842                 line 13
 843                -line 14
 844                +LINE FOURTEEN
 845            "}
 846        ]
 847    );
 848}
 849
 850#[gpui::test]
 851async fn test_empty_prediction(cx: &mut TestAppContext) {
 852    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 853    let fs = FakeFs::new(cx.executor());
 854    fs.insert_tree(
 855        "/root",
 856        json!({
 857            "foo.md":  "Hello!\nHow\nBye\n"
 858        }),
 859    )
 860    .await;
 861    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 862
 863    let buffer = project
 864        .update(cx, |project, cx| {
 865            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 866            project.open_buffer(path, cx)
 867        })
 868        .await
 869        .unwrap();
 870    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 871    let position = snapshot.anchor_before(language::Point::new(1, 3));
 872
 873    ep_store.update(cx, |ep_store, cx| {
 874        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 875    });
 876
 877    let (request, respond_tx) = requests.predict.next().await.unwrap();
 878    let response = model_response(&request, "");
 879    let id = response.request_id.clone();
 880    respond_tx.send(response).unwrap();
 881
 882    cx.run_until_parked();
 883
 884    ep_store.update(cx, |ep_store, cx| {
 885        assert!(
 886            ep_store
 887                .prediction_at(&buffer, None, &project, cx)
 888                .is_none()
 889        );
 890    });
 891
 892    // prediction is reported as rejected
 893    let (reject_request, _) = requests.reject.next().await.unwrap();
 894
 895    assert_eq!(
 896        &reject_request.rejections,
 897        &[EditPredictionRejection {
 898            request_id: id,
 899            reason: EditPredictionRejectReason::Empty,
 900            was_shown: false
 901        }]
 902    );
 903}
 904
/// Verifies that a prediction whose edit is already present in the buffer by
/// the time the response arrives leaves no current prediction, and is reported
/// as rejected with `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is in flight, the buffer is edited to already contain
    // the text the model is about to suggest ("How" -> "How are you?").
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // Interpolated against the edited buffer, the prediction has nothing left
    // to suggest, so no prediction is current.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
 964
/// Unified diff against `root/foo.md` that rewrites the line `How` to
/// `How are you?`, keeping `Hello!` and `Bye` as context. Used as the model
/// output by several tests via `model_response`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 974
/// Verifies that when a second prediction (here, the same diff) completes after
/// a first one is already current, the second becomes current and the first is
/// reported as rejected with `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // the first prediction becomes current
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
1056
/// Verifies that when a newer prediction is considered worse than the current
/// one, the current prediction is kept and the newer one is reported as
/// rejected with `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // the first prediction becomes current
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction: it only inserts "How are" where the
    // current prediction inserts "How are you?"
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
1150
1151#[gpui::test]
1152async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1153    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1154    let fs = FakeFs::new(cx.executor());
1155    fs.insert_tree(
1156        "/root",
1157        json!({
1158            "foo.md":  "Hello!\nHow\nBye\n"
1159        }),
1160    )
1161    .await;
1162    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1163
1164    let buffer = project
1165        .update(cx, |project, cx| {
1166            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1167            project.open_buffer(path, cx)
1168        })
1169        .await
1170        .unwrap();
1171    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1172    let position = snapshot.anchor_before(language::Point::new(1, 3));
1173
1174    // start two refresh tasks
1175    ep_store.update(cx, |ep_store, cx| {
1176        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1177    });
1178
1179    let (request1, respond_first) = requests.predict.next().await.unwrap();
1180
1181    ep_store.update(cx, |ep_store, cx| {
1182        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1183    });
1184
1185    let (request, respond_second) = requests.predict.next().await.unwrap();
1186
1187    // wait for throttle
1188    cx.run_until_parked();
1189
1190    // second responds first
1191    let second_response = model_response(&request, SIMPLE_DIFF);
1192    let second_id = second_response.request_id.clone();
1193    respond_second.send(second_response).unwrap();
1194
1195    cx.run_until_parked();
1196
1197    ep_store.update(cx, |ep_store, cx| {
1198        // current prediction is second
1199        assert_eq!(
1200            ep_store
1201                .prediction_at(&buffer, None, &project, cx)
1202                .unwrap()
1203                .id
1204                .0,
1205            second_id
1206        );
1207    });
1208
1209    let first_response = model_response(&request1, SIMPLE_DIFF);
1210    let first_id = first_response.request_id.clone();
1211    respond_first.send(first_response).unwrap();
1212
1213    cx.run_until_parked();
1214
1215    ep_store.update(cx, |ep_store, cx| {
1216        // current prediction is still second, since first was cancelled
1217        assert_eq!(
1218            ep_store
1219                .prediction_at(&buffer, None, &project, cx)
1220                .unwrap()
1221                .id
1222                .0,
1223            second_id
1224        );
1225    });
1226
1227    // first is reported as rejected
1228    let (reject_request, _) = requests.reject.next().await.unwrap();
1229
1230    cx.run_until_parked();
1231
1232    assert_eq!(
1233        &reject_request.rejections,
1234        &[EditPredictionRejection {
1235            request_id: first_id,
1236            reason: EditPredictionRejectReason::Canceled,
1237            was_shown: false
1238        }]
1239    );
1240}
1241
/// Verifies that when a third refresh starts while two requests are pending,
/// the second request is cancelled: its later response is ignored, the first
/// still becomes current until the third replaces it, and the rejection batch
/// reports the second as `Canceled` and the first as `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // with 2 requests already pending, the 2nd is cancelled; it is recorded
        // as `1` (presumably a 0-based request index — confirm)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // both are reported as rejected in a single batch: the second as canceled,
    // and the first as replaced by the third
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1377
/// Verifies that edit-triggered and jump-triggered (diagnostic) prediction
/// requests are throttled independently. A first request of each kind goes
/// through immediately; a second of each kind is held back until
/// `THROTTLE_TIMEOUT` elapses, after which both are sent.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    // NOTE(review): the executor is not driven between the refresh call above
    // and this check, so an un-throttled request may not have reached the
    // channel yet and the check could pass vacuously. Consider running
    // `cx.run_until_parked()` first (it should not advance the fake clock) —
    // confirm.
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    // NOTE(review): same caveat as above — no executor run before the check.
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through (their relative order is not
    // asserted).
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1478
1479#[gpui::test]
1480async fn test_rejections_flushing(cx: &mut TestAppContext) {
1481    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1482
1483    ep_store.update(cx, |ep_store, cx| {
1484        ep_store.reject_prediction(
1485            EditPredictionId("test-1".into()),
1486            EditPredictionRejectReason::Discarded,
1487            false,
1488            cx,
1489        );
1490        ep_store.reject_prediction(
1491            EditPredictionId("test-2".into()),
1492            EditPredictionRejectReason::Canceled,
1493            true,
1494            cx,
1495        );
1496    });
1497
1498    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1499    cx.run_until_parked();
1500
1501    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1502    respond_tx.send(()).unwrap();
1503
1504    // batched
1505    assert_eq!(reject_request.rejections.len(), 2);
1506    assert_eq!(
1507        reject_request.rejections[0],
1508        EditPredictionRejection {
1509            request_id: "test-1".to_string(),
1510            reason: EditPredictionRejectReason::Discarded,
1511            was_shown: false
1512        }
1513    );
1514    assert_eq!(
1515        reject_request.rejections[1],
1516        EditPredictionRejection {
1517            request_id: "test-2".to_string(),
1518            reason: EditPredictionRejectReason::Canceled,
1519            was_shown: true
1520        }
1521    );
1522
1523    // Reaching batch size limit sends without debounce
1524    ep_store.update(cx, |ep_store, cx| {
1525        for i in 0..70 {
1526            ep_store.reject_prediction(
1527                EditPredictionId(format!("batch-{}", i).into()),
1528                EditPredictionRejectReason::Discarded,
1529                false,
1530                cx,
1531            );
1532        }
1533    });
1534
1535    // First MAX/2 items are sent immediately
1536    cx.run_until_parked();
1537    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1538    respond_tx.send(()).unwrap();
1539
1540    assert_eq!(reject_request.rejections.len(), 50);
1541    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1542    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1543
1544    // Remaining items are debounced with the next batch
1545    cx.executor().advance_clock(Duration::from_secs(15));
1546    cx.run_until_parked();
1547
1548    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1549    respond_tx.send(()).unwrap();
1550
1551    assert_eq!(reject_request.rejections.len(), 20);
1552    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1553    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1554
1555    // Request failure
1556    ep_store.update(cx, |ep_store, cx| {
1557        ep_store.reject_prediction(
1558            EditPredictionId("retry-1".into()),
1559            EditPredictionRejectReason::Discarded,
1560            false,
1561            cx,
1562        );
1563    });
1564
1565    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1566    cx.run_until_parked();
1567
1568    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1569    assert_eq!(reject_request.rejections.len(), 1);
1570    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1571    // Simulate failure
1572    drop(_respond_tx);
1573
1574    // Add another rejection
1575    ep_store.update(cx, |ep_store, cx| {
1576        ep_store.reject_prediction(
1577            EditPredictionId("retry-2".into()),
1578            EditPredictionRejectReason::Discarded,
1579            false,
1580            cx,
1581        );
1582    });
1583
1584    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1585    cx.run_until_parked();
1586
1587    // Retry should include both the failed item and the new one
1588    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1589    respond_tx.send(()).unwrap();
1590
1591    assert_eq!(reject_request.rejections.len(), 2);
1592    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1593    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1594}
1595
1596// Skipped until we start including diagnostics in prompt
1597// #[gpui::test]
1598// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1599//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1600//     let fs = FakeFs::new(cx.executor());
1601//     fs.insert_tree(
1602//         "/root",
1603//         json!({
1604//             "foo.md": "Hello!\nBye"
1605//         }),
1606//     )
1607//     .await;
1608//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1609
1610//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1611//     let diagnostic = lsp::Diagnostic {
1612//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1613//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1614//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1615//         ..Default::default()
1616//     };
1617
1618//     project.update(cx, |project, cx| {
1619//         project.lsp_store().update(cx, |lsp_store, cx| {
1620//             // Create some diagnostics
1621//             lsp_store
1622//                 .update_diagnostics(
1623//                     LanguageServerId(0),
1624//                     lsp::PublishDiagnosticsParams {
1625//                         uri: path_to_buffer_uri.clone(),
1626//                         diagnostics: vec![diagnostic],
1627//                         version: None,
1628//                     },
1629//                     None,
1630//                     language::DiagnosticSourceKind::Pushed,
1631//                     &[],
1632//                     cx,
1633//                 )
1634//                 .unwrap();
1635//         });
1636//     });
1637
1638//     let buffer = project
1639//         .update(cx, |project, cx| {
1640//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1641//             project.open_buffer(path, cx)
1642//         })
1643//         .await
1644//         .unwrap();
1645
1646//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1647//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1648
1649//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1650//         ep_store.request_prediction(&project, &buffer, position, cx)
1651//     });
1652
1653//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1654
1655//     assert_eq!(request.diagnostic_groups.len(), 1);
1656//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1657//         .unwrap();
1658//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1659//     assert_eq!(
1660//         value,
1661//         json!({
1662//             "entries": [{
1663//                 "range": {
1664//                     "start": 8,
1665//                     "end": 10
1666//                 },
1667//                 "diagnostic": {
1668//                     "source": null,
1669//                     "code": null,
1670//                     "code_description": null,
1671//                     "severity": 1,
1672//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1673//                     "markdown": null,
1674//                     "group_id": 0,
1675//                     "is_primary": true,
1676//                     "is_disk_based": false,
1677//                     "is_unnecessary": false,
1678//                     "source_kind": "Pushed",
1679//                     "data": null,
1680//                     "underline": true
1681//                 }
1682//             }],
1683//             "primary_ix": 0
1684//         })
1685//     );
1686// }
1687
1688// Generate a model response that would apply the given diff to the active file.
1689fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1690    let editable_range = request
1691        .input
1692        .excerpt_ranges
1693        .as_ref()
1694        .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1)
1695        .unwrap_or(request.input.editable_range_in_excerpt.clone());
1696    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1697    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1698
1699    PredictEditsV3Response {
1700        request_id: Uuid::new_v4().to_string(),
1701        editable_range,
1702        output: new_excerpt,
1703    }
1704}
1705
1706fn empty_response() -> PredictEditsV3Response {
1707    PredictEditsV3Response {
1708        request_id: Uuid::new_v4().to_string(),
1709        editable_range: 0..0,
1710        output: String::new(),
1711    }
1712}
1713
1714fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1715    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1716}
1717
1718fn assert_no_predict_request_ready(
1719    requests: &mut mpsc::UnboundedReceiver<(
1720        PredictEditsV3Request,
1721        oneshot::Sender<PredictEditsV3Response>,
1722    )>,
1723) {
1724    if requests.next().now_or_never().flatten().is_some() {
1725        panic!("Unexpected prediction request while throttled.");
1726    }
1727}
1728
1729struct RequestChannels {
1730    predict: mpsc::UnboundedReceiver<(
1731        PredictEditsV3Request,
1732        oneshot::Sender<PredictEditsV3Response>,
1733    )>,
1734    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1735}
1736
1737fn init_test_with_fake_client(
1738    cx: &mut TestAppContext,
1739) -> (Entity<EditPredictionStore>, RequestChannels) {
1740    cx.update(move |cx| {
1741        let settings_store = SettingsStore::test(cx);
1742        cx.set_global(settings_store);
1743        zlog::init_test();
1744
1745        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1746        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1747
1748        let http_client = FakeHttpClient::create({
1749            move |req| {
1750                let uri = req.uri().path().to_string();
1751                let mut body = req.into_body();
1752                let predict_req_tx = predict_req_tx.clone();
1753                let reject_req_tx = reject_req_tx.clone();
1754                async move {
1755                    let resp = match uri.as_str() {
1756                        "/client/llm_tokens" => serde_json::to_string(&json!({
1757                            "token": "test"
1758                        }))
1759                        .unwrap(),
1760                        "/predict_edits/v3" => {
1761                            let mut buf = Vec::new();
1762                            body.read_to_end(&mut buf).await.ok();
1763                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
1764                            let req = serde_json::from_slice(&decompressed).unwrap();
1765
1766                            let (res_tx, res_rx) = oneshot::channel();
1767                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1768                            serde_json::to_string(&res_rx.await?).unwrap()
1769                        }
1770                        "/predict_edits/reject" => {
1771                            let mut buf = Vec::new();
1772                            body.read_to_end(&mut buf).await.ok();
1773                            let req = serde_json::from_slice(&buf).unwrap();
1774
1775                            let (res_tx, res_rx) = oneshot::channel();
1776                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1777                            serde_json::to_string(&res_rx.await?).unwrap()
1778                        }
1779                        _ => {
1780                            panic!("Unexpected path: {}", uri)
1781                        }
1782                    };
1783
1784                    Ok(Response::builder().body(resp.into()).unwrap())
1785                }
1786            }
1787        });
1788
1789        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1790        client.cloud_client().set_credentials(1, "test".into());
1791
1792        language_model::init(client.clone(), cx);
1793
1794        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1795        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1796
1797        (
1798            ep_store,
1799            RequestChannels {
1800                predict: predict_req_rx,
1801                reject: reject_req_rx,
1802            },
1803        )
1804    })
1805}
1806
1807#[gpui::test]
1808async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1809    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1810    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1811        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1812    });
1813
1814    let edit_preview = cx
1815        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1816        .await;
1817
1818    let prediction = EditPrediction {
1819        edits,
1820        cursor_position: None,
1821        edit_preview,
1822        buffer: buffer.clone(),
1823        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1824        id: EditPredictionId("the-id".into()),
1825        inputs: ZetaPromptInput {
1826            events: Default::default(),
1827            related_files: Default::default(),
1828            cursor_path: Path::new("").into(),
1829            cursor_excerpt: "".into(),
1830            editable_range_in_excerpt: 0..0,
1831            cursor_offset_in_excerpt: 0,
1832            excerpt_start_row: None,
1833            excerpt_ranges: None,
1834            preferred_model: None,
1835            in_open_source_repo: false,
1836            can_collect_data: false,
1837        },
1838        buffer_snapshotted_at: Instant::now(),
1839        response_received_at: Instant::now(),
1840    };
1841
1842    cx.update(|cx| {
1843        assert_eq!(
1844            from_completion_edits(
1845                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1846                &buffer,
1847                cx
1848            ),
1849            vec![(2..5, "REM".into()), (9..11, "".into())]
1850        );
1851
1852        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1853        assert_eq!(
1854            from_completion_edits(
1855                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1856                &buffer,
1857                cx
1858            ),
1859            vec![(2..2, "REM".into()), (6..8, "".into())]
1860        );
1861
1862        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1863        assert_eq!(
1864            from_completion_edits(
1865                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1866                &buffer,
1867                cx
1868            ),
1869            vec![(2..5, "REM".into()), (9..11, "".into())]
1870        );
1871
1872        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1873        assert_eq!(
1874            from_completion_edits(
1875                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1876                &buffer,
1877                cx
1878            ),
1879            vec![(3..3, "EM".into()), (7..9, "".into())]
1880        );
1881
1882        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1883        assert_eq!(
1884            from_completion_edits(
1885                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1886                &buffer,
1887                cx
1888            ),
1889            vec![(4..4, "M".into()), (8..10, "".into())]
1890        );
1891
1892        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1893        assert_eq!(
1894            from_completion_edits(
1895                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1896                &buffer,
1897                cx
1898            ),
1899            vec![(9..11, "".into())]
1900        );
1901
1902        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1903        assert_eq!(
1904            from_completion_edits(
1905                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1906                &buffer,
1907                cx
1908            ),
1909            vec![(4..4, "M".into()), (8..10, "".into())]
1910        );
1911
1912        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1913        assert_eq!(
1914            from_completion_edits(
1915                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1916                &buffer,
1917                cx
1918            ),
1919            vec![(4..4, "M".into())]
1920        );
1921
1922        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1923        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1924    })
1925}
1926
1927#[gpui::test]
1928async fn test_clean_up_diff(cx: &mut TestAppContext) {
1929    init_test(cx);
1930
1931    assert_eq!(
1932        apply_edit_prediction(
1933            indoc! {"
1934                    fn main() {
1935                        let word_1 = \"lorem\";
1936                        let range = word.len()..word.len();
1937                    }
1938                "},
1939            indoc! {"
1940                    fn main() {
1941                        let word_1 = \"lorem\";
1942                        let range = word_1.len()..word_1.len();
1943                    }
1944                "},
1945            cx,
1946        )
1947        .await,
1948        indoc! {"
1949                fn main() {
1950                    let word_1 = \"lorem\";
1951                    let range = word_1.len()..word_1.len();
1952                }
1953            "},
1954    );
1955
1956    assert_eq!(
1957        apply_edit_prediction(
1958            indoc! {"
1959                    fn main() {
1960                        let story = \"the quick\"
1961                    }
1962                "},
1963            indoc! {"
1964                    fn main() {
1965                        let story = \"the quick brown fox jumps over the lazy dog\";
1966                    }
1967                "},
1968            cx,
1969        )
1970        .await,
1971        indoc! {"
1972                fn main() {
1973                    let story = \"the quick brown fox jumps over the lazy dog\";
1974                }
1975            "},
1976    );
1977}
1978
1979#[gpui::test]
1980async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1981    init_test(cx);
1982
1983    let buffer_content = "lorem\n";
1984    let completion_response = "lorem\nipsum\n";
1985
1986    assert_eq!(
1987        apply_edit_prediction(buffer_content, completion_response, cx).await,
1988        "lorem\nipsum\n"
1989    );
1990}
1991
1992#[gpui::test]
1993async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1994    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1995    // When the buffer ends without a trailing newline, but the model returns output
1996    // with a trailing newline, zeta2 should normalize both sides before diffing
1997    // so no spurious newline is inserted.
1998    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1999    let fs = FakeFs::new(cx.executor());
2000
2001    // Single line buffer with no trailing newline
2002    fs.insert_tree(
2003        "/root",
2004        json!({
2005            "foo.txt": "hello"
2006        }),
2007    )
2008    .await;
2009    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2010
2011    let buffer = project
2012        .update(cx, |project, cx| {
2013            let path = project
2014                .find_project_path(path!("root/foo.txt"), cx)
2015                .unwrap();
2016            project.open_buffer(path, cx)
2017        })
2018        .await
2019        .unwrap();
2020
2021    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2022    let position = snapshot.anchor_before(language::Point::new(0, 5));
2023
2024    ep_store.update(cx, |ep_store, cx| {
2025        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2026    });
2027
2028    let (request, respond_tx) = requests.predict.next().await.unwrap();
2029
2030    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2031    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2032    let excerpt_length = request.input.cursor_excerpt.len();
2033    let response = PredictEditsV3Response {
2034        request_id: Uuid::new_v4().to_string(),
2035        output: "hello world\n".to_string(),
2036        editable_range: 0..excerpt_length,
2037    };
2038    respond_tx.send(response).unwrap();
2039
2040    cx.run_until_parked();
2041
2042    // The prediction should insert " world" without adding a newline
2043    ep_store.update(cx, |ep_store, cx| {
2044        let prediction = ep_store
2045            .prediction_at(&buffer, None, &project, cx)
2046            .expect("should have prediction");
2047        let edits: Vec<_> = prediction
2048            .edits
2049            .iter()
2050            .map(|(range, text)| {
2051                let snapshot = buffer.read(cx).snapshot();
2052                (range.to_offset(&snapshot), text.clone())
2053            })
2054            .collect();
2055        assert_eq!(edits, vec![(5..5, " world".into())]);
2056    });
2057}
2058
2059fn init_test(cx: &mut TestAppContext) {
2060    cx.update(|cx| {
2061        let settings_store = SettingsStore::test(cx);
2062        cx.set_global(settings_store);
2063    });
2064}
2065
2066async fn apply_edit_prediction(
2067    buffer_content: &str,
2068    completion_response: &str,
2069    cx: &mut TestAppContext,
2070) -> String {
2071    let fs = project::FakeFs::new(cx.executor());
2072    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2073    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2074    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2075    *response.lock() = completion_response.to_string();
2076    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2077    buffer.update(cx, |buffer, cx| {
2078        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2079    });
2080    buffer.read_with(cx, |buffer, _| buffer.text())
2081}
2082
2083async fn run_edit_prediction(
2084    buffer: &Entity<Buffer>,
2085    project: &Entity<Project>,
2086    ep_store: &Entity<EditPredictionStore>,
2087    cx: &mut TestAppContext,
2088) -> EditPrediction {
2089    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2090    ep_store.update(cx, |ep_store, cx| {
2091        ep_store.register_buffer(buffer, &project, cx)
2092    });
2093    cx.background_executor.run_until_parked();
2094    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2095        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2096    });
2097    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2098}
2099
2100async fn make_test_ep_store(
2101    project: &Entity<Project>,
2102    cx: &mut TestAppContext,
2103) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2104    let default_response = "hello world\n".to_string();
2105    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2106    let http_client = FakeHttpClient::create({
2107        let completion_response = completion_response.clone();
2108        let mut next_request_id = 0;
2109        move |req| {
2110            let completion_response = completion_response.clone();
2111            let method = req.method().clone();
2112            let uri = req.uri().path().to_string();
2113            let mut body = req.into_body();
2114            async move {
2115                match (method, uri.as_str()) {
2116                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2117                        .status(200)
2118                        .body(
2119                            serde_json::to_string(&CreateLlmTokenResponse {
2120                                token: LlmToken("the-llm-token".to_string()),
2121                            })
2122                            .unwrap()
2123                            .into(),
2124                        )
2125                        .unwrap()),
2126                    (Method::POST, "/predict_edits/v3") => {
2127                        let mut buf = Vec::new();
2128                        body.read_to_end(&mut buf).await.ok();
2129                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
2130                        let req: PredictEditsV3Request =
2131                            serde_json::from_slice(&decompressed).unwrap();
2132
2133                        next_request_id += 1;
2134                        Ok(http_client::Response::builder()
2135                            .status(200)
2136                            .body(
2137                                serde_json::to_string(&PredictEditsV3Response {
2138                                    request_id: format!("request-{next_request_id}"),
2139                                    editable_range: 0..req.input.cursor_excerpt.len(),
2140                                    output: completion_response.lock().clone(),
2141                                })
2142                                .unwrap()
2143                                .into(),
2144                            )
2145                            .unwrap())
2146                    }
2147                    _ => Ok(http_client::Response::builder()
2148                        .status(404)
2149                        .body("Not Found".to_string().into())
2150                        .unwrap()),
2151                }
2152            }
2153        }
2154    });
2155
2156    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2157    cx.update(|cx| {
2158        RefreshLlmTokenListener::register(client.clone(), cx);
2159    });
2160    let _server = FakeServer::for_client(42, &client, cx).await;
2161
2162    let ep_store = cx.new(|cx| {
2163        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2164        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2165
2166        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2167        for worktree in worktrees {
2168            let worktree_id = worktree.read(cx).id();
2169            ep_store
2170                .get_or_init_project(project, cx)
2171                .license_detection_watchers
2172                .entry(worktree_id)
2173                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2174        }
2175
2176        ep_store
2177    });
2178
2179    (ep_store, completion_response)
2180}
2181
2182fn to_completion_edits(
2183    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2184    buffer: &Entity<Buffer>,
2185    cx: &App,
2186) -> Vec<(Range<Anchor>, Arc<str>)> {
2187    let buffer = buffer.read(cx);
2188    iterator
2189        .into_iter()
2190        .map(|(range, text)| {
2191            (
2192                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2193                text,
2194            )
2195        })
2196        .collect()
2197}
2198
2199fn from_completion_edits(
2200    editor_edits: &[(Range<Anchor>, Arc<str>)],
2201    buffer: &Entity<Buffer>,
2202    cx: &App,
2203) -> Vec<(Range<usize>, Arc<str>)> {
2204    let buffer = buffer.read(cx);
2205    editor_edits
2206        .iter()
2207        .map(|(range, text)| {
2208            (
2209                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2210                text.clone(),
2211            )
2212        })
2213        .collect()
2214}
2215
2216#[gpui::test]
2217async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2218    init_test(cx);
2219
2220    let fs = FakeFs::new(cx.executor());
2221    fs.insert_tree(
2222        "/project",
2223        serde_json::json!({
2224            "main.rs": "fn main() {\n    \n}\n"
2225        }),
2226    )
2227    .await;
2228
2229    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2230
2231    let http_client = FakeHttpClient::create(|_req| async move {
2232        Ok(gpui::http_client::Response::builder()
2233            .status(401)
2234            .body("Unauthorized".into())
2235            .unwrap())
2236    });
2237
2238    let client =
2239        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2240    cx.update(|cx| {
2241        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2242    });
2243
2244    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2245
2246    let buffer = project
2247        .update(cx, |project, cx| {
2248            let path = project
2249                .find_project_path(path!("/project/main.rs"), cx)
2250                .unwrap();
2251            project.open_buffer(path, cx)
2252        })
2253        .await
2254        .unwrap();
2255
2256    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2257    ep_store.update(cx, |ep_store, cx| {
2258        ep_store.register_buffer(&buffer, &project, cx)
2259    });
2260    cx.background_executor.run_until_parked();
2261
2262    let completion_task = ep_store.update(cx, |ep_store, cx| {
2263        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2264        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2265    });
2266
2267    let result = completion_task.await;
2268    assert!(
2269        result.is_err(),
2270        "Without authentication and without custom URL, prediction should fail"
2271    );
2272}
2273
2274#[gpui::test]
2275fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2276    let buffer = cx.new(|cx| {
2277        Buffer::local(
2278            indoc! {"
2279                zero
2280                one
2281                two
2282                three
2283                four
2284                five
2285                six
2286                seven
2287                eight
2288                nine
2289                ten
2290                eleven
2291                twelve
2292                thirteen
2293                fourteen
2294                fifteen
2295                sixteen
2296                seventeen
2297                eighteen
2298                nineteen
2299                twenty
2300                twenty-one
2301                twenty-two
2302                twenty-three
2303                twenty-four
2304            "},
2305            cx,
2306        )
2307    });
2308
2309    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2310
2311    buffer.update(cx, |buffer, cx| {
2312        let point = Point::new(12, 0);
2313        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2314        let point = Point::new(8, 0);
2315        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2316    });
2317
2318    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2319
2320    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2321
2322    assert_eq!(
2323        diff,
2324        indoc! {"
2325            @@ -6,10 +6,12 @@
2326             five
2327             six
2328             seven
2329            +FIRST INSERTION
2330             eight
2331             nine
2332             ten
2333             eleven
2334            +SECOND INSERTION
2335             twelve
2336             thirteen
2337             fourteen
2338            "}
2339    );
2340}
2341
/// Checks that `EditPredictionStore::next_diagnostic_location` avoids
/// diagnostics near a collaborator's cursor unless the local cursor is near
/// them too. This is checked both inside the active buffer and when a
/// diagnostic is picked from another file.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    /// Places a collaborator's cursor (a collapsed selection) at the start of
    /// `row` in `buffer`. It does so by applying an
    /// `Operation::UpdateSelections` stamped with replica 10.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        // start == end, so the selection is just a cursor.
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    // Attribute the selection to the collaborator's replica, not the local one.
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    /// Pushes one ERROR diagnostic covering columns 0..5 of each row in `rows`
    /// for the file at `uri_path`, attributed to `LanguageServerId(0)`.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // The active file has 60 lines: "line 0" through "line 59".
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // The collaborator sits on row 5; the active file has diagnostics on rows 3, 25 and 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: the local cursor is on row 25, far from the collaborator.
    // Row 3 (near the collaborator) must not be chosen.
    let cursor_point = Point::new(25, 0);
    // NOTE(review): an empty (default) search range. Presumably no part of the
    // active buffer is treated as already searched; confirm in next_diagnostic_location.
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: the local cursor is on row 4, which is also near the collaborator.
    // Row 3 becomes eligible and is expected.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: the local cursor is on row 50, so the diagnostic on row 50 is expected.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4 (cross-file): both other files get a diagnostic on row 0, but only
    // collab_file.txt has a collaborator cursor.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // NOTE(review): judging by its name, this range covers the active file, so
    // that no same-file diagnostic is picked and the search falls back to other
    // files. Confirm.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2557
/// Initializes `zlog` test logging for this test binary. `#[ctor::ctor]` runs
/// this function at program load, before any test starts.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}