edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
   3use client::{UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
   8    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
   9};
  10use futures::{
  11    AsyncReadExt, FutureExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::App;
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::path;
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
// Exercises the store's current-prediction state in two phases: a prediction
// requested for the active buffer (`1.txt`) and, once that one has been
// rejected, a prediction answered after a diagnostic is published for `2.txt`.
// The second prediction is expected to be reported as a `Jump` when queried
// from `1.txt` and as `Local` when queried from `2.txt`.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open `1.txt` and mark it as the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor position: row 1, column 3, i.e. right after "How".
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Answer the pending request with a diff that edits `1.txt` itself.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction edits the buffer it is queried from, so it must be `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Reject the first prediction before moving on to the diagnostic-driven case.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publish the diagnostic against `2.txt`, which has not been opened as a buffer yet.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // No explicit refresh is issued here; the test expects a predict request to
    // arrive after the diagnostics update alone.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from `1.txt`, the prediction targets another file, so it must be a
    // `Jump` whose snapshot belongs to `root/2.txt`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from `2.txt` itself, the prediction is reported as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 169
 170#[gpui::test]
 171async fn test_simple_request(cx: &mut TestAppContext) {
 172    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 173    let fs = FakeFs::new(cx.executor());
 174    fs.insert_tree(
 175        "/root",
 176        json!({
 177            "foo.md":  "Hello!\nHow\nBye\n"
 178        }),
 179    )
 180    .await;
 181    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 182
 183    let buffer = project
 184        .update(cx, |project, cx| {
 185            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 186            project.open_buffer(path, cx)
 187        })
 188        .await
 189        .unwrap();
 190    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 191    let position = snapshot.anchor_before(language::Point::new(1, 3));
 192
 193    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 194        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 195    });
 196
 197    let (request, respond_tx) = requests.predict.next().await.unwrap();
 198
 199    // TODO Put back when we have a structured request again
 200    // assert_eq!(
 201    //     request.excerpt_path.as_ref(),
 202    //     Path::new(path!("root/foo.md"))
 203    // );
 204    // assert_eq!(
 205    //     request.cursor_point,
 206    //     Point {
 207    //         line: Line(1),
 208    //         column: 3
 209    //     }
 210    // );
 211
 212    respond_tx
 213        .send(model_response(
 214            &request,
 215            indoc! { r"
 216                --- a/root/foo.md
 217                +++ b/root/foo.md
 218                @@ ... @@
 219                 Hello!
 220                -How
 221                +How are you?
 222                 Bye
 223            "},
 224        ))
 225        .unwrap();
 226
 227    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 228
 229    assert_eq!(prediction.edits.len(), 1);
 230    assert_eq!(
 231        prediction.edits[0].0.to_point(&snapshot).start,
 232        language::Point::new(1, 3)
 233    );
 234    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 235}
 236
 237#[gpui::test]
 238async fn test_request_events(cx: &mut TestAppContext) {
 239    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 240    let fs = FakeFs::new(cx.executor());
 241    fs.insert_tree(
 242        "/root",
 243        json!({
 244            "foo.md": "Hello!\n\nBye\n"
 245        }),
 246    )
 247    .await;
 248    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 249
 250    let buffer = project
 251        .update(cx, |project, cx| {
 252            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 253            project.open_buffer(path, cx)
 254        })
 255        .await
 256        .unwrap();
 257
 258    ep_store.update(cx, |ep_store, cx| {
 259        ep_store.register_buffer(&buffer, &project, cx);
 260    });
 261
 262    buffer.update(cx, |buffer, cx| {
 263        buffer.edit(vec![(7..7, "How")], None, cx);
 264    });
 265
 266    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 267    let position = snapshot.anchor_before(language::Point::new(1, 3));
 268
 269    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 270        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 271    });
 272
 273    let (request, respond_tx) = requests.predict.next().await.unwrap();
 274
 275    let prompt = prompt_from_request(&request);
 276    assert!(
 277        prompt.contains(indoc! {"
 278        --- a/root/foo.md
 279        +++ b/root/foo.md
 280        @@ -1,3 +1,3 @@
 281         Hello!
 282        -
 283        +How
 284         Bye
 285    "}),
 286        "{prompt}"
 287    );
 288
 289    respond_tx
 290        .send(model_response(
 291            &request,
 292            indoc! {r#"
 293                --- a/root/foo.md
 294                +++ b/root/foo.md
 295                @@ ... @@
 296                 Hello!
 297                -How
 298                +How are you?
 299                 Bye
 300        "#},
 301        ))
 302        .unwrap();
 303
 304    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 305
 306    assert_eq!(prediction.edits.len(), 1);
 307    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 308}
 309
// Checks that the edit-history getter reports two events when a pause longer
// than `LAST_CHANGE_GROUPING_TIME` separates two bursts of typing on the same
// line: one diff for the text typed before the pause and one for the text
// typed after it.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: only the pre-pause burst ("How").
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second event: both post-pause edits (" are you?" and "!"), diffed against
    // the text as it stood after the first burst.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 387
// Walks through four manual edits on a 30-line file and checks the rendered
// edit history after each one. Despite the function name, the assertions
// visible here cover distance-based coalescing: edits close to the existing
// changed range (above or below it) are merged into a single event, while an
// edit far away from it produces a separate event.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // `render_events` joins events with "\n---\n", hence the blank line and the
    // `---` separator between the two hunks below.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 566
 567fn render_events(events: &[StoredEvent]) -> String {
 568    events
 569        .iter()
 570        .map(|e| {
 571            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 572            diff.as_str()
 573        })
 574        .collect::<Vec<_>>()
 575        .join("\n---\n")
 576}
 577
 578fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 579    events
 580        .iter()
 581        .map(|e| {
 582            let zeta_prompt::Event::BufferChange {
 583                diff, predicted, ..
 584            } = e.event.as_ref();
 585            let prefix = if *predicted { "predicted" } else { "manual" };
 586            format!("{}\n{}", prefix, diff)
 587        })
 588        .collect()
 589}
 590
// Steps through a sequence of manual and predicted edits on adjacent lines
// and checks, after each step, how the edit history groups them and which
// groups are labelled `predicted` versus `manual`. Each step builds on the
// buffer state left by the previous one, so the cases must run in order.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // This step also covers the mirror image of Case 5 below: a predicted edit
    // that follows a manual edit is not merged with the manual edit, even if it
    // is nearby.
    //
    // The edit is reported explicitly with the flag argument set to `true`; the
    // expected output below labels the resulting event `predicted`.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5: A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // The three earlier events collapse into one `manual` event (it mixes manual
    // and predicted edits), followed by a new `manual` event for line 14.
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
 849
 850#[gpui::test]
 851async fn test_empty_prediction(cx: &mut TestAppContext) {
 852    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 853    let fs = FakeFs::new(cx.executor());
 854    fs.insert_tree(
 855        "/root",
 856        json!({
 857            "foo.md":  "Hello!\nHow\nBye\n"
 858        }),
 859    )
 860    .await;
 861    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 862
 863    let buffer = project
 864        .update(cx, |project, cx| {
 865            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 866            project.open_buffer(path, cx)
 867        })
 868        .await
 869        .unwrap();
 870    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 871    let position = snapshot.anchor_before(language::Point::new(1, 3));
 872
 873    ep_store.update(cx, |ep_store, cx| {
 874        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 875    });
 876
 877    let (request, respond_tx) = requests.predict.next().await.unwrap();
 878    let response = model_response(&request, "");
 879    let id = response.request_id.clone();
 880    respond_tx.send(response).unwrap();
 881
 882    cx.run_until_parked();
 883
 884    ep_store.update(cx, |ep_store, cx| {
 885        assert!(
 886            ep_store
 887                .prediction_at(&buffer, None, &project, cx)
 888                .is_none()
 889        );
 890    });
 891
 892    // prediction is reported as rejected
 893    let (reject_request, _) = requests.reject.next().await.unwrap();
 894
 895    assert_eq!(
 896        &reject_request.rejections,
 897        &[EditPredictionRejection {
 898            request_id: id,
 899            reason: EditPredictionRejectReason::Empty,
 900            was_shown: false
 901        }]
 902    );
 903}
 904
 905#[gpui::test]
 906async fn test_interpolated_empty(cx: &mut TestAppContext) {
 907    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 908    let fs = FakeFs::new(cx.executor());
 909    fs.insert_tree(
 910        "/root",
 911        json!({
 912            "foo.md":  "Hello!\nHow\nBye\n"
 913        }),
 914    )
 915    .await;
 916    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 917
 918    let buffer = project
 919        .update(cx, |project, cx| {
 920            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 921            project.open_buffer(path, cx)
 922        })
 923        .await
 924        .unwrap();
 925    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 926    let position = snapshot.anchor_before(language::Point::new(1, 3));
 927
 928    ep_store.update(cx, |ep_store, cx| {
 929        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 930    });
 931
 932    let (request, respond_tx) = requests.predict.next().await.unwrap();
 933
 934    buffer.update(cx, |buffer, cx| {
 935        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 936    });
 937
 938    let response = model_response(&request, SIMPLE_DIFF);
 939    let id = response.request_id.clone();
 940    respond_tx.send(response).unwrap();
 941
 942    cx.run_until_parked();
 943
 944    ep_store.update(cx, |ep_store, cx| {
 945        assert!(
 946            ep_store
 947                .prediction_at(&buffer, None, &project, cx)
 948                .is_none()
 949        );
 950    });
 951
 952    // prediction is reported as rejected
 953    let (reject_request, _) = requests.reject.next().await.unwrap();
 954
 955    assert_eq!(
 956        &reject_request.rejections,
 957        &[EditPredictionRejection {
 958            request_id: id,
 959            reason: EditPredictionRejectReason::InterpolatedEmpty,
 960            was_shown: false
 961        }]
 962    );
 963}
 964
/// Unified diff against `root/foo.md` that replaces the line `How` with
/// `How are you?`, keeping `Hello!` and `Bye` as context lines. Several tests
/// in this file pass it to `model_response` as the model's predicted edit.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 974
 975#[gpui::test]
 976async fn test_replace_current(cx: &mut TestAppContext) {
 977    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 978    let fs = FakeFs::new(cx.executor());
 979    fs.insert_tree(
 980        "/root",
 981        json!({
 982            "foo.md":  "Hello!\nHow\nBye\n"
 983        }),
 984    )
 985    .await;
 986    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 987
 988    let buffer = project
 989        .update(cx, |project, cx| {
 990            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 991            project.open_buffer(path, cx)
 992        })
 993        .await
 994        .unwrap();
 995    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 996    let position = snapshot.anchor_before(language::Point::new(1, 3));
 997
 998    ep_store.update(cx, |ep_store, cx| {
 999        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1000    });
1001
1002    let (request, respond_tx) = requests.predict.next().await.unwrap();
1003    let first_response = model_response(&request, SIMPLE_DIFF);
1004    let first_id = first_response.request_id.clone();
1005    respond_tx.send(first_response).unwrap();
1006
1007    cx.run_until_parked();
1008
1009    ep_store.update(cx, |ep_store, cx| {
1010        assert_eq!(
1011            ep_store
1012                .prediction_at(&buffer, None, &project, cx)
1013                .unwrap()
1014                .id
1015                .0,
1016            first_id
1017        );
1018    });
1019
1020    // a second request is triggered
1021    ep_store.update(cx, |ep_store, cx| {
1022        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1023    });
1024
1025    let (request, respond_tx) = requests.predict.next().await.unwrap();
1026    let second_response = model_response(&request, SIMPLE_DIFF);
1027    let second_id = second_response.request_id.clone();
1028    respond_tx.send(second_response).unwrap();
1029
1030    cx.run_until_parked();
1031
1032    ep_store.update(cx, |ep_store, cx| {
1033        // second replaces first
1034        assert_eq!(
1035            ep_store
1036                .prediction_at(&buffer, None, &project, cx)
1037                .unwrap()
1038                .id
1039                .0,
1040            second_id
1041        );
1042    });
1043
1044    // first is reported as replaced
1045    let (reject_request, _) = requests.reject.next().await.unwrap();
1046
1047    assert_eq!(
1048        &reject_request.rejections,
1049        &[EditPredictionRejection {
1050            request_id: first_id,
1051            reason: EditPredictionRejectReason::Replaced,
1052            was_shown: false
1053        }]
1054    );
1055}
1056
1057#[gpui::test]
1058async fn test_current_preferred(cx: &mut TestAppContext) {
1059    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1060    let fs = FakeFs::new(cx.executor());
1061    fs.insert_tree(
1062        "/root",
1063        json!({
1064            "foo.md":  "Hello!\nHow\nBye\n"
1065        }),
1066    )
1067    .await;
1068    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1069
1070    let buffer = project
1071        .update(cx, |project, cx| {
1072            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1073            project.open_buffer(path, cx)
1074        })
1075        .await
1076        .unwrap();
1077    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1078    let position = snapshot.anchor_before(language::Point::new(1, 3));
1079
1080    ep_store.update(cx, |ep_store, cx| {
1081        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1082    });
1083
1084    let (request, respond_tx) = requests.predict.next().await.unwrap();
1085    let first_response = model_response(&request, SIMPLE_DIFF);
1086    let first_id = first_response.request_id.clone();
1087    respond_tx.send(first_response).unwrap();
1088
1089    cx.run_until_parked();
1090
1091    ep_store.update(cx, |ep_store, cx| {
1092        assert_eq!(
1093            ep_store
1094                .prediction_at(&buffer, None, &project, cx)
1095                .unwrap()
1096                .id
1097                .0,
1098            first_id
1099        );
1100    });
1101
1102    // a second request is triggered
1103    ep_store.update(cx, |ep_store, cx| {
1104        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1105    });
1106
1107    let (request, respond_tx) = requests.predict.next().await.unwrap();
1108    // worse than current prediction
1109    let second_response = model_response(
1110        &request,
1111        indoc! { r"
1112            --- a/root/foo.md
1113            +++ b/root/foo.md
1114            @@ ... @@
1115             Hello!
1116            -How
1117            +How are
1118             Bye
1119        "},
1120    );
1121    let second_id = second_response.request_id.clone();
1122    respond_tx.send(second_response).unwrap();
1123
1124    cx.run_until_parked();
1125
1126    ep_store.update(cx, |ep_store, cx| {
1127        // first is preferred over second
1128        assert_eq!(
1129            ep_store
1130                .prediction_at(&buffer, None, &project, cx)
1131                .unwrap()
1132                .id
1133                .0,
1134            first_id
1135        );
1136    });
1137
1138    // second is reported as rejected
1139    let (reject_request, _) = requests.reject.next().await.unwrap();
1140
1141    assert_eq!(
1142        &reject_request.rejections,
1143        &[EditPredictionRejection {
1144            request_id: second_id,
1145            reason: EditPredictionRejectReason::CurrentPreferred,
1146            was_shown: false
1147        }]
1148    );
1149}
1150
1151#[gpui::test]
1152async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1153    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1154    let fs = FakeFs::new(cx.executor());
1155    fs.insert_tree(
1156        "/root",
1157        json!({
1158            "foo.md":  "Hello!\nHow\nBye\n"
1159        }),
1160    )
1161    .await;
1162    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1163
1164    let buffer = project
1165        .update(cx, |project, cx| {
1166            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1167            project.open_buffer(path, cx)
1168        })
1169        .await
1170        .unwrap();
1171    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1172    let position = snapshot.anchor_before(language::Point::new(1, 3));
1173
1174    // start two refresh tasks
1175    ep_store.update(cx, |ep_store, cx| {
1176        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1177    });
1178
1179    let (request1, respond_first) = requests.predict.next().await.unwrap();
1180
1181    ep_store.update(cx, |ep_store, cx| {
1182        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1183    });
1184
1185    let (request, respond_second) = requests.predict.next().await.unwrap();
1186
1187    // wait for throttle
1188    cx.run_until_parked();
1189
1190    // second responds first
1191    let second_response = model_response(&request, SIMPLE_DIFF);
1192    let second_id = second_response.request_id.clone();
1193    respond_second.send(second_response).unwrap();
1194
1195    cx.run_until_parked();
1196
1197    ep_store.update(cx, |ep_store, cx| {
1198        // current prediction is second
1199        assert_eq!(
1200            ep_store
1201                .prediction_at(&buffer, None, &project, cx)
1202                .unwrap()
1203                .id
1204                .0,
1205            second_id
1206        );
1207    });
1208
1209    let first_response = model_response(&request1, SIMPLE_DIFF);
1210    let first_id = first_response.request_id.clone();
1211    respond_first.send(first_response).unwrap();
1212
1213    cx.run_until_parked();
1214
1215    ep_store.update(cx, |ep_store, cx| {
1216        // current prediction is still second, since first was cancelled
1217        assert_eq!(
1218            ep_store
1219                .prediction_at(&buffer, None, &project, cx)
1220                .unwrap()
1221                .id
1222                .0,
1223            second_id
1224        );
1225    });
1226
1227    // first is reported as rejected
1228    let (reject_request, _) = requests.reject.next().await.unwrap();
1229
1230    cx.run_until_parked();
1231
1232    assert_eq!(
1233        &reject_request.rejections,
1234        &[EditPredictionRejection {
1235            request_id: first_id,
1236            reason: EditPredictionRejectReason::Canceled,
1237            was_shown: false
1238        }]
1239    );
1240}
1241
1242#[gpui::test]
1243async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
1244    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1245    let fs = FakeFs::new(cx.executor());
1246    fs.insert_tree(
1247        "/root",
1248        json!({
1249            "foo.md":  "Hello!\nHow\nBye\n"
1250        }),
1251    )
1252    .await;
1253    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1254
1255    let buffer = project
1256        .update(cx, |project, cx| {
1257            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1258            project.open_buffer(path, cx)
1259        })
1260        .await
1261        .unwrap();
1262    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1263    let position = snapshot.anchor_before(language::Point::new(1, 3));
1264
1265    // start two refresh tasks
1266    ep_store.update(cx, |ep_store, cx| {
1267        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1268    });
1269
1270    let (request1, respond_first) = requests.predict.next().await.unwrap();
1271
1272    ep_store.update(cx, |ep_store, cx| {
1273        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1274    });
1275
1276    let (request2, respond_second) = requests.predict.next().await.unwrap();
1277
1278    // wait for throttle, so requests are sent
1279    cx.run_until_parked();
1280
1281    ep_store.update(cx, |ep_store, cx| {
1282        // start a third request
1283        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1284
1285        // 2 are pending, so 2nd is cancelled
1286        assert_eq!(
1287            ep_store
1288                .get_or_init_project(&project, cx)
1289                .cancelled_predictions
1290                .iter()
1291                .copied()
1292                .collect::<Vec<_>>(),
1293            [1]
1294        );
1295    });
1296
1297    // wait for throttle
1298    cx.run_until_parked();
1299
1300    let (request3, respond_third) = requests.predict.next().await.unwrap();
1301
1302    let first_response = model_response(&request1, SIMPLE_DIFF);
1303    let first_id = first_response.request_id.clone();
1304    respond_first.send(first_response).unwrap();
1305
1306    cx.run_until_parked();
1307
1308    ep_store.update(cx, |ep_store, cx| {
1309        // current prediction is first
1310        assert_eq!(
1311            ep_store
1312                .prediction_at(&buffer, None, &project, cx)
1313                .unwrap()
1314                .id
1315                .0,
1316            first_id
1317        );
1318    });
1319
1320    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
1321    let cancelled_id = cancelled_response.request_id.clone();
1322    respond_second.send(cancelled_response).unwrap();
1323
1324    cx.run_until_parked();
1325
1326    ep_store.update(cx, |ep_store, cx| {
1327        // current prediction is still first, since second was cancelled
1328        assert_eq!(
1329            ep_store
1330                .prediction_at(&buffer, None, &project, cx)
1331                .unwrap()
1332                .id
1333                .0,
1334            first_id
1335        );
1336    });
1337
1338    let third_response = model_response(&request3, SIMPLE_DIFF);
1339    let third_response_id = third_response.request_id.clone();
1340    respond_third.send(third_response).unwrap();
1341
1342    cx.run_until_parked();
1343
1344    ep_store.update(cx, |ep_store, cx| {
1345        // third completes and replaces first
1346        assert_eq!(
1347            ep_store
1348                .prediction_at(&buffer, None, &project, cx)
1349                .unwrap()
1350                .id
1351                .0,
1352            third_response_id
1353        );
1354    });
1355
1356    // second is reported as rejected
1357    let (reject_request, _) = requests.reject.next().await.unwrap();
1358
1359    cx.run_until_parked();
1360
1361    assert_eq!(
1362        &reject_request.rejections,
1363        &[
1364            EditPredictionRejection {
1365                request_id: cancelled_id,
1366                reason: EditPredictionRejectReason::Canceled,
1367                was_shown: false
1368            },
1369            EditPredictionRejection {
1370                request_id: first_id,
1371                reason: EditPredictionRejectReason::Replaced,
1372                was_shown: false
1373            }
1374        ]
1375    );
1376}
1377
1378#[gpui::test]
1379async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
1380    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1381
1382    let fs = FakeFs::new(cx.executor());
1383    fs.insert_tree(
1384        "/root",
1385        json!({
1386            "foo.md":  "Hello!\nHow\nBye\n",
1387            "bar.md": "Hola!\nComo\nAdios\n"
1388        }),
1389    )
1390    .await;
1391    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1392
1393    let buffer = project
1394        .update(cx, |project, cx| {
1395            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1396            project.set_active_path(Some(path.clone()), cx);
1397            project.open_buffer(path, cx)
1398        })
1399        .await
1400        .unwrap();
1401    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1402    let position = snapshot.anchor_before(language::Point::new(1, 3));
1403
1404    ep_store.update(cx, |ep_store, cx| {
1405        ep_store.register_project(&project, cx);
1406        ep_store.register_buffer(&buffer, &project, cx);
1407    });
1408
1409    // First edit request - no prior edit, so not throttled.
1410    ep_store.update(cx, |ep_store, cx| {
1411        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1412    });
1413    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
1414    edit_response_tx.send(empty_response()).unwrap();
1415    cx.run_until_parked();
1416
1417    let diagnostic = lsp::Diagnostic {
1418        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1419        severity: Some(lsp::DiagnosticSeverity::ERROR),
1420        message: "Sentence is incomplete".to_string(),
1421        ..Default::default()
1422    };
1423
1424    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
1425    project.update(cx, |project, cx| {
1426        project.lsp_store().update(cx, |lsp_store, cx| {
1427            lsp_store
1428                .update_diagnostics(
1429                    LanguageServerId(0),
1430                    lsp::PublishDiagnosticsParams {
1431                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
1432                        diagnostics: vec![diagnostic],
1433                        version: None,
1434                    },
1435                    None,
1436                    language::DiagnosticSourceKind::Pushed,
1437                    &[],
1438                    cx,
1439                )
1440                .unwrap();
1441        });
1442    });
1443    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
1444    jump_response_tx.send(empty_response()).unwrap();
1445    cx.run_until_parked();
1446
1447    // Second edit request - should be throttled by the first edit.
1448    ep_store.update(cx, |ep_store, cx| {
1449        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1450    });
1451    assert_no_predict_request_ready(&mut requests.predict);
1452
1453    // Second jump request - should be throttled by the first jump.
1454    ep_store.update(cx, |ep_store, cx| {
1455        ep_store.refresh_prediction_from_diagnostics(
1456            project.clone(),
1457            DiagnosticSearchScope::Global,
1458            cx,
1459        );
1460    });
1461    assert_no_predict_request_ready(&mut requests.predict);
1462
1463    // Wait for both throttles to expire.
1464    cx.background_executor
1465        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
1466    cx.background_executor.run_until_parked();
1467    cx.run_until_parked();
1468
1469    // Both requests should now go through.
1470    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
1471    response_tx_1.send(empty_response()).unwrap();
1472    cx.run_until_parked();
1473
1474    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
1475    response_tx_2.send(empty_response()).unwrap();
1476    cx.run_until_parked();
1477}
1478
1479#[gpui::test]
1480async fn test_rejections_flushing(cx: &mut TestAppContext) {
1481    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1482
1483    ep_store.update(cx, |ep_store, cx| {
1484        ep_store.reject_prediction(
1485            EditPredictionId("test-1".into()),
1486            EditPredictionRejectReason::Discarded,
1487            false,
1488            cx,
1489        );
1490        ep_store.reject_prediction(
1491            EditPredictionId("test-2".into()),
1492            EditPredictionRejectReason::Canceled,
1493            true,
1494            cx,
1495        );
1496    });
1497
1498    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1499    cx.run_until_parked();
1500
1501    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1502    respond_tx.send(()).unwrap();
1503
1504    // batched
1505    assert_eq!(reject_request.rejections.len(), 2);
1506    assert_eq!(
1507        reject_request.rejections[0],
1508        EditPredictionRejection {
1509            request_id: "test-1".to_string(),
1510            reason: EditPredictionRejectReason::Discarded,
1511            was_shown: false
1512        }
1513    );
1514    assert_eq!(
1515        reject_request.rejections[1],
1516        EditPredictionRejection {
1517            request_id: "test-2".to_string(),
1518            reason: EditPredictionRejectReason::Canceled,
1519            was_shown: true
1520        }
1521    );
1522
1523    // Reaching batch size limit sends without debounce
1524    ep_store.update(cx, |ep_store, cx| {
1525        for i in 0..70 {
1526            ep_store.reject_prediction(
1527                EditPredictionId(format!("batch-{}", i).into()),
1528                EditPredictionRejectReason::Discarded,
1529                false,
1530                cx,
1531            );
1532        }
1533    });
1534
1535    // First MAX/2 items are sent immediately
1536    cx.run_until_parked();
1537    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1538    respond_tx.send(()).unwrap();
1539
1540    assert_eq!(reject_request.rejections.len(), 50);
1541    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1542    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1543
1544    // Remaining items are debounced with the next batch
1545    cx.executor().advance_clock(Duration::from_secs(15));
1546    cx.run_until_parked();
1547
1548    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1549    respond_tx.send(()).unwrap();
1550
1551    assert_eq!(reject_request.rejections.len(), 20);
1552    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1553    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1554
1555    // Request failure
1556    ep_store.update(cx, |ep_store, cx| {
1557        ep_store.reject_prediction(
1558            EditPredictionId("retry-1".into()),
1559            EditPredictionRejectReason::Discarded,
1560            false,
1561            cx,
1562        );
1563    });
1564
1565    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1566    cx.run_until_parked();
1567
1568    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1569    assert_eq!(reject_request.rejections.len(), 1);
1570    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1571    // Simulate failure
1572    drop(_respond_tx);
1573
1574    // Add another rejection
1575    ep_store.update(cx, |ep_store, cx| {
1576        ep_store.reject_prediction(
1577            EditPredictionId("retry-2".into()),
1578            EditPredictionRejectReason::Discarded,
1579            false,
1580            cx,
1581        );
1582    });
1583
1584    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1585    cx.run_until_parked();
1586
1587    // Retry should include both the failed item and the new one
1588    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1589    respond_tx.send(()).unwrap();
1590
1591    assert_eq!(reject_request.rejections.len(), 2);
1592    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1593    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1594}
1595
1596// Skipped until we start including diagnostics in prompt
1597// #[gpui::test]
1598// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1599//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1600//     let fs = FakeFs::new(cx.executor());
1601//     fs.insert_tree(
1602//         "/root",
1603//         json!({
1604//             "foo.md": "Hello!\nBye"
1605//         }),
1606//     )
1607//     .await;
1608//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1609
1610//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1611//     let diagnostic = lsp::Diagnostic {
1612//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1613//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1614//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1615//         ..Default::default()
1616//     };
1617
1618//     project.update(cx, |project, cx| {
1619//         project.lsp_store().update(cx, |lsp_store, cx| {
1620//             // Create some diagnostics
1621//             lsp_store
1622//                 .update_diagnostics(
1623//                     LanguageServerId(0),
1624//                     lsp::PublishDiagnosticsParams {
1625//                         uri: path_to_buffer_uri.clone(),
1626//                         diagnostics: vec![diagnostic],
1627//                         version: None,
1628//                     },
1629//                     None,
1630//                     language::DiagnosticSourceKind::Pushed,
1631//                     &[],
1632//                     cx,
1633//                 )
1634//                 .unwrap();
1635//         });
1636//     });
1637
1638//     let buffer = project
1639//         .update(cx, |project, cx| {
1640//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1641//             project.open_buffer(path, cx)
1642//         })
1643//         .await
1644//         .unwrap();
1645
1646//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1647//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1648
1649//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1650//         ep_store.request_prediction(&project, &buffer, position, cx)
1651//     });
1652
1653//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1654
1655//     assert_eq!(request.diagnostic_groups.len(), 1);
1656//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1657//         .unwrap();
1658//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1659//     assert_eq!(
1660//         value,
1661//         json!({
1662//             "entries": [{
1663//                 "range": {
1664//                     "start": 8,
1665//                     "end": 10
1666//                 },
1667//                 "diagnostic": {
1668//                     "source": null,
1669//                     "code": null,
1670//                     "code_description": null,
1671//                     "severity": 1,
1672//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1673//                     "markdown": null,
1674//                     "group_id": 0,
1675//                     "is_primary": true,
1676//                     "is_disk_based": false,
1677//                     "is_unnecessary": false,
1678//                     "source_kind": "Pushed",
1679//                     "data": null,
1680//                     "underline": true
1681//                 }
1682//             }],
1683//             "primary_ix": 0
1684//         })
1685//     );
1686// }
1687
1688// Generate a model response that would apply the given diff to the active file.
1689fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1690    let excerpt =
1691        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1692    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1693
1694    PredictEditsV3Response {
1695        request_id: Uuid::new_v4().to_string(),
1696        output: new_excerpt,
1697    }
1698}
1699
1700fn empty_response() -> PredictEditsV3Response {
1701    PredictEditsV3Response {
1702        request_id: Uuid::new_v4().to_string(),
1703        output: String::new(),
1704    }
1705}
1706
1707fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1708    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1709}
1710
1711fn assert_no_predict_request_ready(
1712    requests: &mut mpsc::UnboundedReceiver<(
1713        PredictEditsV3Request,
1714        oneshot::Sender<PredictEditsV3Response>,
1715    )>,
1716) {
1717    if requests.next().now_or_never().flatten().is_some() {
1718        panic!("Unexpected prediction request while throttled.");
1719    }
1720}
1721
1722struct RequestChannels {
1723    predict: mpsc::UnboundedReceiver<(
1724        PredictEditsV3Request,
1725        oneshot::Sender<PredictEditsV3Response>,
1726    )>,
1727    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1728}
1729
1730fn init_test_with_fake_client(
1731    cx: &mut TestAppContext,
1732) -> (Entity<EditPredictionStore>, RequestChannels) {
1733    cx.update(move |cx| {
1734        let settings_store = SettingsStore::test(cx);
1735        cx.set_global(settings_store);
1736        zlog::init_test();
1737
1738        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1739        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1740
1741        let http_client = FakeHttpClient::create({
1742            move |req| {
1743                let uri = req.uri().path().to_string();
1744                let mut body = req.into_body();
1745                let predict_req_tx = predict_req_tx.clone();
1746                let reject_req_tx = reject_req_tx.clone();
1747                async move {
1748                    let resp = match uri.as_str() {
1749                        "/client/llm_tokens" => serde_json::to_string(&json!({
1750                            "token": "test"
1751                        }))
1752                        .unwrap(),
1753                        "/predict_edits/v3" => {
1754                            let mut buf = Vec::new();
1755                            body.read_to_end(&mut buf).await.ok();
1756                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
1757                            let req = serde_json::from_slice(&decompressed).unwrap();
1758
1759                            let (res_tx, res_rx) = oneshot::channel();
1760                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1761                            serde_json::to_string(&res_rx.await?).unwrap()
1762                        }
1763                        "/predict_edits/reject" => {
1764                            let mut buf = Vec::new();
1765                            body.read_to_end(&mut buf).await.ok();
1766                            let req = serde_json::from_slice(&buf).unwrap();
1767
1768                            let (res_tx, res_rx) = oneshot::channel();
1769                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1770                            serde_json::to_string(&res_rx.await?).unwrap()
1771                        }
1772                        _ => {
1773                            panic!("Unexpected path: {}", uri)
1774                        }
1775                    };
1776
1777                    Ok(Response::builder().body(resp.into()).unwrap())
1778                }
1779            }
1780        });
1781
1782        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1783        client.cloud_client().set_credentials(1, "test".into());
1784
1785        language_model::init(client.clone(), cx);
1786
1787        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1788        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1789
1790        (
1791            ep_store,
1792            RequestChannels {
1793                predict: predict_req_rx,
1794                reject: reject_req_rx,
1795            },
1796        )
1797    })
1798}
1799
1800#[gpui::test]
1801async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1802    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1803    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1804        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1805    });
1806
1807    let edit_preview = cx
1808        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1809        .await;
1810
1811    let prediction = EditPrediction {
1812        edits,
1813        cursor_position: None,
1814        edit_preview,
1815        buffer: buffer.clone(),
1816        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1817        id: EditPredictionId("the-id".into()),
1818        inputs: ZetaPromptInput {
1819            events: Default::default(),
1820            related_files: Default::default(),
1821            cursor_path: Path::new("").into(),
1822            cursor_excerpt: "".into(),
1823            editable_range_in_excerpt: 0..0,
1824            cursor_offset_in_excerpt: 0,
1825            excerpt_start_row: None,
1826            excerpt_ranges: None,
1827            preferred_model: None,
1828            in_open_source_repo: false,
1829            can_collect_data: false,
1830        },
1831        buffer_snapshotted_at: Instant::now(),
1832        response_received_at: Instant::now(),
1833    };
1834
1835    cx.update(|cx| {
1836        assert_eq!(
1837            from_completion_edits(
1838                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1839                &buffer,
1840                cx
1841            ),
1842            vec![(2..5, "REM".into()), (9..11, "".into())]
1843        );
1844
1845        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1846        assert_eq!(
1847            from_completion_edits(
1848                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1849                &buffer,
1850                cx
1851            ),
1852            vec![(2..2, "REM".into()), (6..8, "".into())]
1853        );
1854
1855        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1856        assert_eq!(
1857            from_completion_edits(
1858                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1859                &buffer,
1860                cx
1861            ),
1862            vec![(2..5, "REM".into()), (9..11, "".into())]
1863        );
1864
1865        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1866        assert_eq!(
1867            from_completion_edits(
1868                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1869                &buffer,
1870                cx
1871            ),
1872            vec![(3..3, "EM".into()), (7..9, "".into())]
1873        );
1874
1875        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1876        assert_eq!(
1877            from_completion_edits(
1878                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1879                &buffer,
1880                cx
1881            ),
1882            vec![(4..4, "M".into()), (8..10, "".into())]
1883        );
1884
1885        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1886        assert_eq!(
1887            from_completion_edits(
1888                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1889                &buffer,
1890                cx
1891            ),
1892            vec![(9..11, "".into())]
1893        );
1894
1895        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1896        assert_eq!(
1897            from_completion_edits(
1898                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1899                &buffer,
1900                cx
1901            ),
1902            vec![(4..4, "M".into()), (8..10, "".into())]
1903        );
1904
1905        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1906        assert_eq!(
1907            from_completion_edits(
1908                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1909                &buffer,
1910                cx
1911            ),
1912            vec![(4..4, "M".into())]
1913        );
1914
1915        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1916        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1917    })
1918}
1919
1920#[gpui::test]
1921async fn test_clean_up_diff(cx: &mut TestAppContext) {
1922    init_test(cx);
1923
1924    assert_eq!(
1925        apply_edit_prediction(
1926            indoc! {"
1927                    fn main() {
1928                        let word_1 = \"lorem\";
1929                        let range = word.len()..word.len();
1930                    }
1931                "},
1932            indoc! {"
1933                    fn main() {
1934                        let word_1 = \"lorem\";
1935                        let range = word_1.len()..word_1.len();
1936                    }
1937                "},
1938            cx,
1939        )
1940        .await,
1941        indoc! {"
1942                fn main() {
1943                    let word_1 = \"lorem\";
1944                    let range = word_1.len()..word_1.len();
1945                }
1946            "},
1947    );
1948
1949    assert_eq!(
1950        apply_edit_prediction(
1951            indoc! {"
1952                    fn main() {
1953                        let story = \"the quick\"
1954                    }
1955                "},
1956            indoc! {"
1957                    fn main() {
1958                        let story = \"the quick brown fox jumps over the lazy dog\";
1959                    }
1960                "},
1961            cx,
1962        )
1963        .await,
1964        indoc! {"
1965                fn main() {
1966                    let story = \"the quick brown fox jumps over the lazy dog\";
1967                }
1968            "},
1969    );
1970}
1971
1972#[gpui::test]
1973async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1974    init_test(cx);
1975
1976    let buffer_content = "lorem\n";
1977    let completion_response = "lorem\nipsum\n";
1978
1979    assert_eq!(
1980        apply_edit_prediction(buffer_content, completion_response, cx).await,
1981        "lorem\nipsum\n"
1982    );
1983}
1984
1985#[gpui::test]
1986async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1987    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1988    // When the buffer ends without a trailing newline, but the model returns output
1989    // with a trailing newline, zeta2 should normalize both sides before diffing
1990    // so no spurious newline is inserted.
1991    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1992    let fs = FakeFs::new(cx.executor());
1993
1994    // Single line buffer with no trailing newline
1995    fs.insert_tree(
1996        "/root",
1997        json!({
1998            "foo.txt": "hello"
1999        }),
2000    )
2001    .await;
2002    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2003
2004    let buffer = project
2005        .update(cx, |project, cx| {
2006            let path = project
2007                .find_project_path(path!("root/foo.txt"), cx)
2008                .unwrap();
2009            project.open_buffer(path, cx)
2010        })
2011        .await
2012        .unwrap();
2013
2014    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2015    let position = snapshot.anchor_before(language::Point::new(0, 5));
2016
2017    ep_store.update(cx, |ep_store, cx| {
2018        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2019    });
2020
2021    let (_request, respond_tx) = requests.predict.next().await.unwrap();
2022
2023    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2024    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2025    let response = PredictEditsV3Response {
2026        request_id: Uuid::new_v4().to_string(),
2027        output: "hello world\n".to_string(),
2028    };
2029    respond_tx.send(response).unwrap();
2030
2031    cx.run_until_parked();
2032
2033    // The prediction should insert " world" without adding a newline
2034    ep_store.update(cx, |ep_store, cx| {
2035        let prediction = ep_store
2036            .prediction_at(&buffer, None, &project, cx)
2037            .expect("should have prediction");
2038        let edits: Vec<_> = prediction
2039            .edits
2040            .iter()
2041            .map(|(range, text)| {
2042                let snapshot = buffer.read(cx).snapshot();
2043                (range.to_offset(&snapshot), text.clone())
2044            })
2045            .collect();
2046        assert_eq!(edits, vec![(5..5, " world".into())]);
2047    });
2048}
2049
2050fn init_test(cx: &mut TestAppContext) {
2051    cx.update(|cx| {
2052        let settings_store = SettingsStore::test(cx);
2053        cx.set_global(settings_store);
2054    });
2055}
2056
2057async fn apply_edit_prediction(
2058    buffer_content: &str,
2059    completion_response: &str,
2060    cx: &mut TestAppContext,
2061) -> String {
2062    let fs = project::FakeFs::new(cx.executor());
2063    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2064    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2065    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2066    *response.lock() = completion_response.to_string();
2067    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2068    buffer.update(cx, |buffer, cx| {
2069        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2070    });
2071    buffer.read_with(cx, |buffer, _| buffer.text())
2072}
2073
2074async fn run_edit_prediction(
2075    buffer: &Entity<Buffer>,
2076    project: &Entity<Project>,
2077    ep_store: &Entity<EditPredictionStore>,
2078    cx: &mut TestAppContext,
2079) -> EditPrediction {
2080    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2081    ep_store.update(cx, |ep_store, cx| {
2082        ep_store.register_buffer(buffer, &project, cx)
2083    });
2084    cx.background_executor.run_until_parked();
2085    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2086        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2087    });
2088    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2089}
2090
2091async fn make_test_ep_store(
2092    project: &Entity<Project>,
2093    cx: &mut TestAppContext,
2094) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2095    let default_response = "hello world\n".to_string();
2096    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2097    let http_client = FakeHttpClient::create({
2098        let completion_response = completion_response.clone();
2099        let mut next_request_id = 0;
2100        move |req| {
2101            let completion_response = completion_response.clone();
2102            async move {
2103                match (req.method(), req.uri().path()) {
2104                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2105                        .status(200)
2106                        .body(
2107                            serde_json::to_string(&CreateLlmTokenResponse {
2108                                token: LlmToken("the-llm-token".to_string()),
2109                            })
2110                            .unwrap()
2111                            .into(),
2112                        )
2113                        .unwrap()),
2114                    (&Method::POST, "/predict_edits/v3") => {
2115                        next_request_id += 1;
2116                        Ok(http_client::Response::builder()
2117                            .status(200)
2118                            .body(
2119                                serde_json::to_string(&PredictEditsV3Response {
2120                                    request_id: format!("request-{next_request_id}"),
2121                                    output: completion_response.lock().clone(),
2122                                })
2123                                .unwrap()
2124                                .into(),
2125                            )
2126                            .unwrap())
2127                    }
2128                    _ => Ok(http_client::Response::builder()
2129                        .status(404)
2130                        .body("Not Found".into())
2131                        .unwrap()),
2132                }
2133            }
2134        }
2135    });
2136
2137    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2138    cx.update(|cx| {
2139        RefreshLlmTokenListener::register(client.clone(), cx);
2140    });
2141    let _server = FakeServer::for_client(42, &client, cx).await;
2142
2143    let ep_store = cx.new(|cx| {
2144        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2145        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2146
2147        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2148        for worktree in worktrees {
2149            let worktree_id = worktree.read(cx).id();
2150            ep_store
2151                .get_or_init_project(project, cx)
2152                .license_detection_watchers
2153                .entry(worktree_id)
2154                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2155        }
2156
2157        ep_store
2158    });
2159
2160    (ep_store, completion_response)
2161}
2162
2163fn to_completion_edits(
2164    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2165    buffer: &Entity<Buffer>,
2166    cx: &App,
2167) -> Vec<(Range<Anchor>, Arc<str>)> {
2168    let buffer = buffer.read(cx);
2169    iterator
2170        .into_iter()
2171        .map(|(range, text)| {
2172            (
2173                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2174                text,
2175            )
2176        })
2177        .collect()
2178}
2179
2180fn from_completion_edits(
2181    editor_edits: &[(Range<Anchor>, Arc<str>)],
2182    buffer: &Entity<Buffer>,
2183    cx: &App,
2184) -> Vec<(Range<usize>, Arc<str>)> {
2185    let buffer = buffer.read(cx);
2186    editor_edits
2187        .iter()
2188        .map(|(range, text)| {
2189            (
2190                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2191                text.clone(),
2192            )
2193        })
2194        .collect()
2195}
2196
2197#[gpui::test]
2198async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2199    init_test(cx);
2200
2201    let fs = FakeFs::new(cx.executor());
2202    fs.insert_tree(
2203        "/project",
2204        serde_json::json!({
2205            "main.rs": "fn main() {\n    \n}\n"
2206        }),
2207    )
2208    .await;
2209
2210    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2211
2212    let http_client = FakeHttpClient::create(|_req| async move {
2213        Ok(gpui::http_client::Response::builder()
2214            .status(401)
2215            .body("Unauthorized".into())
2216            .unwrap())
2217    });
2218
2219    let client =
2220        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2221    cx.update(|cx| {
2222        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2223    });
2224
2225    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2226
2227    let buffer = project
2228        .update(cx, |project, cx| {
2229            let path = project
2230                .find_project_path(path!("/project/main.rs"), cx)
2231                .unwrap();
2232            project.open_buffer(path, cx)
2233        })
2234        .await
2235        .unwrap();
2236
2237    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2238    ep_store.update(cx, |ep_store, cx| {
2239        ep_store.register_buffer(&buffer, &project, cx)
2240    });
2241    cx.background_executor.run_until_parked();
2242
2243    let completion_task = ep_store.update(cx, |ep_store, cx| {
2244        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2245        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2246    });
2247
2248    let result = completion_task.await;
2249    assert!(
2250        result.is_err(),
2251        "Without authentication and without custom URL, prediction should fail"
2252    );
2253}
2254
2255#[gpui::test]
2256fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2257    let buffer = cx.new(|cx| {
2258        Buffer::local(
2259            indoc! {"
2260                zero
2261                one
2262                two
2263                three
2264                four
2265                five
2266                six
2267                seven
2268                eight
2269                nine
2270                ten
2271                eleven
2272                twelve
2273                thirteen
2274                fourteen
2275                fifteen
2276                sixteen
2277                seventeen
2278                eighteen
2279                nineteen
2280                twenty
2281                twenty-one
2282                twenty-two
2283                twenty-three
2284                twenty-four
2285            "},
2286            cx,
2287        )
2288    });
2289
2290    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2291
2292    buffer.update(cx, |buffer, cx| {
2293        let point = Point::new(12, 0);
2294        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2295        let point = Point::new(8, 0);
2296        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2297    });
2298
2299    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2300
2301    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2302
2303    assert_eq!(
2304        diff,
2305        indoc! {"
2306            @@ -6,10 +6,12 @@
2307             five
2308             six
2309             seven
2310            +FIRST INSERTION
2311             eight
2312             nine
2313             ten
2314             eleven
2315            +SECOND INSERTION
2316             twelve
2317             thirteen
2318             fourteen
2319            "}
2320    );
2321}
2322
2323#[gpui::test]
2324async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
2325    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
2326        let collab_replica = clock::ReplicaId::new(10);
2327        let anchor = buffer.read_with(cx, |buffer, _| {
2328            buffer.snapshot().anchor_before(Point::new(row, 0))
2329        });
2330        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
2331            id: 1,
2332            start: anchor,
2333            end: anchor,
2334            reversed: false,
2335            goal: SelectionGoal::None,
2336        }]);
2337        buffer.update(cx, |buffer, cx| {
2338            buffer.apply_ops(
2339                [Operation::UpdateSelections {
2340                    selections,
2341                    lamport_timestamp: clock::Lamport {
2342                        replica_id: collab_replica,
2343                        value: 1,
2344                    },
2345                    line_mode: false,
2346                    cursor_shape: CursorShape::Bar,
2347                }],
2348                cx,
2349            );
2350        });
2351    }
2352
2353    fn publish_diagnostics(
2354        uri_path: &'static str,
2355        rows: &[u32],
2356        project: &Entity<Project>,
2357        cx: &mut TestAppContext,
2358    ) {
2359        let diagnostics: Vec<_> = rows
2360            .iter()
2361            .map(|&row| lsp::Diagnostic {
2362                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
2363                severity: Some(lsp::DiagnosticSeverity::ERROR),
2364                message: format!("error at row {row}"),
2365                ..Default::default()
2366            })
2367            .collect();
2368        project.update(cx, |project, cx| {
2369            project.lsp_store().update(cx, |lsp_store, cx| {
2370                lsp_store
2371                    .update_diagnostics(
2372                        LanguageServerId(0),
2373                        lsp::PublishDiagnosticsParams {
2374                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
2375                            diagnostics,
2376                            version: None,
2377                        },
2378                        None,
2379                        language::DiagnosticSourceKind::Pushed,
2380                        &[],
2381                        cx,
2382                    )
2383                    .expect("failed to update diagnostics");
2384            });
2385        });
2386    }
2387
2388    init_test(cx);
2389
2390    let mut lines = String::new();
2391    for i in 0..60 {
2392        lines.push_str(&format!("line {i}\n"));
2393    }
2394
2395    let fs = FakeFs::new(cx.executor());
2396    fs.insert_tree(
2397        "/root",
2398        json!({
2399            "active.txt": lines,
2400            "collab_file.txt": "error here\nsecond line\n",
2401            "free_file.txt": "another error\nsecond line\n",
2402        }),
2403    )
2404    .await;
2405    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2406
2407    let active_buffer = project
2408        .update(cx, |project, cx| {
2409            let path = project
2410                .find_project_path(path!("/root/active.txt"), cx)
2411                .expect("active.txt not found");
2412            project.set_active_path(Some(path.clone()), cx);
2413            project.open_buffer(path, cx)
2414        })
2415        .await
2416        .expect("failed to open active buffer");
2417
2418    set_collaborator_cursor(&active_buffer, 5, cx);
2419
2420    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);
2421
2422    cx.run_until_parked();
2423
2424    let cursor_point = Point::new(25, 0);
2425    let empty_search_range: Range<Point> = Default::default();
2426
2427    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2428    let result = EditPredictionStore::next_diagnostic_location(
2429        active_buffer.clone(),
2430        &snapshot,
2431        empty_search_range.clone(),
2432        cursor_point,
2433        &project,
2434        &mut cx.to_async(),
2435    )
2436    .await
2437    .expect("next_diagnostic_location failed");
2438
2439    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
2440    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
2441    let result_row = result_buffer.read_with(cx, |buffer, _| {
2442        result_anchor.to_point(&buffer.snapshot()).row
2443    });
2444    assert_ne!(
2445        result_row, 3,
2446        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
2447    );
2448    assert!(
2449        result_row == 25 || result_row == 50,
2450        "expected row 25 or 50, got {result_row}"
2451    );
2452
2453    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2454    let near_cursor_point = Point::new(4, 0);
2455    let result_near = EditPredictionStore::next_diagnostic_location(
2456        active_buffer.clone(),
2457        &snapshot_near,
2458        empty_search_range.clone(),
2459        near_cursor_point,
2460        &project,
2461        &mut cx.to_async(),
2462    )
2463    .await
2464    .expect("next_diagnostic_location failed");
2465
2466    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
2467    let near_row =
2468        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
2469    assert_eq!(
2470        near_row, 3,
2471        "row 3 should be included when local cursor (row 4) is also near the collaborator"
2472    );
2473
2474    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2475    let far_cursor_point = Point::new(50, 0);
2476    let result_far = EditPredictionStore::next_diagnostic_location(
2477        active_buffer.clone(),
2478        &snapshot_far,
2479        empty_search_range.clone(),
2480        far_cursor_point,
2481        &project,
2482        &mut cx.to_async(),
2483    )
2484    .await
2485    .expect("next_diagnostic_location failed");
2486
2487    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
2488    let far_row =
2489        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
2490    assert_eq!(
2491        far_row, 50,
2492        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
2493    );
2494
2495    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
2496    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
2497    cx.run_until_parked();
2498
2499    let collab_buffer = project
2500        .update(cx, |project, cx| {
2501            let path = project
2502                .find_project_path(path!("/root/collab_file.txt"), cx)
2503                .expect("collab_file.txt not found");
2504            project.open_buffer(path, cx)
2505        })
2506        .await
2507        .expect("failed to open collab buffer");
2508
2509    set_collaborator_cursor(&collab_buffer, 0, cx);
2510    cx.run_until_parked();
2511
2512    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
2513    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
2514    let result_cross = EditPredictionStore::next_diagnostic_location(
2515        active_buffer.clone(),
2516        &snapshot_cross,
2517        no_same_file_search_range,
2518        Point::new(0, 0),
2519        &project,
2520        &mut cx.to_async(),
2521    )
2522    .await
2523    .expect("cross-file next_diagnostic_location failed");
2524
2525    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
2526    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
2527        buffer
2528            .file()
2529            .expect("buffer should have a file")
2530            .full_path(cx)
2531    });
2532    assert_eq!(
2533        cross_path,
2534        Path::new(path!("root/free_file.txt")),
2535        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
2536    );
2537}
2538
2539#[ctor::ctor]
2540fn init_logger() {
2541    zlog::init_test();
2542}