edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
   3use client::{UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
   8    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
   9};
  10use futures::{
  11    AsyncReadExt, FutureExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::App;
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::path;
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{
  33    BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
  34    EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
  35};
  36
// Verifies how `prediction_at` classifies the store's current prediction:
// 1. A prediction for the buffer being edited (`1.txt`) is reported as `Local`.
// 2. After that prediction is rejected, an error diagnostic is published for `2.txt`
//    and the test expects another prediction request. Its response edits `2.txt`, so
//    from `1.txt` it is reported as a `Jump` to `root/2.txt`, and from `2.txt`'s own
//    buffer it is reported as `Local`.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open `1.txt` and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at row 1, column 3: the end of the line `How`.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The fake model answers with a diff that extends `How` to `How are you?` in 1.txt.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    // Drive pending tasks so the store handles the response before it is queried.
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the local prediction before exercising the cross-file scenario.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Push the diagnostic for `2.txt` as if it came from language server 0. The next
    // `requests.predict` item is expected to be triggered by this.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from `1.txt`, the prediction edits another file, so it must be a jump to it.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from `2.txt` itself, the same prediction is local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 172
// Requests a prediction directly with `request_prediction` and checks the result.
// The model's diff (`How` -> `How are you?`) must become exactly one edit that inserts
// " are you?" at row 1, column 3 of the original snapshot.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of the line `How`.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Start the request; the returned task is awaited only after the fake model responds.
    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The line-level diff is reduced to a single insertion right after `How`.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 239
// Checks that a manual edit made after `register_buffer` appears, as a unified diff of
// `root/foo.md`, in the prompt of the next prediction request. It also checks that the
// model's response still yields the expected single edit.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Offset 7 is the start of the empty second line (`Hello!\n` is 7 bytes); fill it with `How`.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The edit above must be part of the prompt; on failure the whole prompt is printed.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 312
// Checks that `edit_history_for_project` reports two separate events when edits on the
// same line are separated by a pause longer than `LAST_CHANGE_GROUPING_TIME`. The edit
// before the pause and the edits after it each get their own diff.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold by advancing the fake clock
    // by twice `LAST_CHANGE_GROUPING_TIME`.
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line
    // (offset 10 = 7 for `Hello!\n` + 3 for `How`).
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    // Offset 19 is just after " are you?" (10 + 9).
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 0: the pre-pause edit only.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 1: both post-pause edits, diffed against the state after event 0.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 390
// Exercises line-span coalescing of manual edits in the edit history. An edit close to the
// previous event's range (within 8 lines, per the inline comments, above or below it) is
// merged into that event, while an edit farther away starts a new event.
// NOTE(review): despite its name, this test makes no predicted edits; the
// predicted/manual split is exercised by `test_predicted_flag_coalescing`.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing.
    // Initially, buffer row N holds `Line {N + 1}`.
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range.
    // At this point row 13 holds `Line 14` and row 14 holds `Line 15`.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event.
    // After the edits above, row N (for N >= 15) holds `Line N` again.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // The blank line below comes from the first diff's trailing newline followed by the
    // `\n---\n` separator that `render_events` inserts between events.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 569
 570fn render_events(events: &[StoredEvent]) -> String {
 571    events
 572        .iter()
 573        .map(|e| {
 574            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 575            diff.as_str()
 576        })
 577        .collect::<Vec<_>>()
 578        .join("\n---\n")
 579}
 580
 581fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 582    events
 583        .iter()
 584        .map(|e| {
 585            let zeta_prompt::Event::BufferChange {
 586                diff, predicted, ..
 587            } = e.event.as_ref();
 588            let prefix = if *predicted { "predicted" } else { "manual" };
 589            format!("{}\n{}", prefix, diff)
 590        })
 591        .collect()
 592}
 593
// Checks how the `predicted` flag interacts with edit-history coalescing:
// - nearby manual edits merge into one `manual` event;
// - nearby accepted predictions merge into one `predicted` event;
// - a predicted edit and an adjacent manual edit are kept as separate events;
// - once an edit lands far away, the earlier edits are merged into one event that is
//   `predicted` only if every constituent edit was predicted.
// Initially, buffer row N holds `line N`.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true. The edit is made and
    // then reported via `report_changes_for_buffer(.., true, ..)` within the same store
    // update. As the expectation shows, it becomes a separate `predicted` event instead
    // of merging into the nearby manual event.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    // (`line 14` is 7 bytes long, hence the end column of 7.)
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
 852
// An empty model response produces no prediction for the buffer. It is reported
// through the reject channel with reason `Empty` and `was_shown: false`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Answer with an empty diff; keep its request id to match against the rejection.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
 908
// If the buffer already contains the predicted text when the response arrives, the
// prediction interpolates to nothing. No prediction is available, and it is reported as
// rejected with reason `InterpolatedEmpty` and `was_shown: false`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is in flight, the buffer is changed to already contain the text
    // that `SIMPLE_DIFF` would insert.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
 969
/// Unified diff used as the fake model output by several tests (applied to the
/// request's editable excerpt via `model_response`). In `root/foo.md` it replaces
/// the line `How` with `How are you?`, keeping `Hello!` and `Bye` as context.
/// Note the hunk header is `@@ ... @@`, i.e. it carries no line numbers.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 979
 980#[gpui::test]
 981async fn test_replace_current(cx: &mut TestAppContext) {
 982    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 983    let fs = FakeFs::new(cx.executor());
 984    fs.insert_tree(
 985        "/root",
 986        json!({
 987            "foo.md":  "Hello!\nHow\nBye\n"
 988        }),
 989    )
 990    .await;
 991    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 992
 993    let buffer = project
 994        .update(cx, |project, cx| {
 995            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 996            project.open_buffer(path, cx)
 997        })
 998        .await
 999        .unwrap();
1000    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1001    let position = snapshot.anchor_before(language::Point::new(1, 3));
1002
1003    ep_store.update(cx, |ep_store, cx| {
1004        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1005    });
1006
1007    let (request, respond_tx) = requests.predict.next().await.unwrap();
1008    let first_response = model_response(&request, SIMPLE_DIFF);
1009    let first_id = first_response.request_id.clone();
1010    respond_tx.send(first_response).unwrap();
1011
1012    cx.run_until_parked();
1013
1014    ep_store.update(cx, |ep_store, cx| {
1015        assert_eq!(
1016            ep_store
1017                .prediction_at(&buffer, None, &project, cx)
1018                .unwrap()
1019                .id
1020                .0,
1021            first_id
1022        );
1023    });
1024
1025    // a second request is triggered
1026    ep_store.update(cx, |ep_store, cx| {
1027        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1028    });
1029
1030    let (request, respond_tx) = requests.predict.next().await.unwrap();
1031    let second_response = model_response(&request, SIMPLE_DIFF);
1032    let second_id = second_response.request_id.clone();
1033    respond_tx.send(second_response).unwrap();
1034
1035    cx.run_until_parked();
1036
1037    ep_store.update(cx, |ep_store, cx| {
1038        // second replaces first
1039        assert_eq!(
1040            ep_store
1041                .prediction_at(&buffer, None, &project, cx)
1042                .unwrap()
1043                .id
1044                .0,
1045            second_id
1046        );
1047    });
1048
1049    // first is reported as replaced
1050    let (reject_request, _) = requests.reject.next().await.unwrap();
1051
1052    assert_eq!(
1053        &reject_request.rejections,
1054        &[EditPredictionRejection {
1055            request_id: first_id,
1056            reason: EditPredictionRejectReason::Replaced,
1057            was_shown: false,
1058            model_version: None,
1059        }]
1060    );
1061}
1062
1063#[gpui::test]
1064async fn test_current_preferred(cx: &mut TestAppContext) {
1065    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1066    let fs = FakeFs::new(cx.executor());
1067    fs.insert_tree(
1068        "/root",
1069        json!({
1070            "foo.md":  "Hello!\nHow\nBye\n"
1071        }),
1072    )
1073    .await;
1074    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1075
1076    let buffer = project
1077        .update(cx, |project, cx| {
1078            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1079            project.open_buffer(path, cx)
1080        })
1081        .await
1082        .unwrap();
1083    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1084    let position = snapshot.anchor_before(language::Point::new(1, 3));
1085
1086    ep_store.update(cx, |ep_store, cx| {
1087        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1088    });
1089
1090    let (request, respond_tx) = requests.predict.next().await.unwrap();
1091    let first_response = model_response(&request, SIMPLE_DIFF);
1092    let first_id = first_response.request_id.clone();
1093    respond_tx.send(first_response).unwrap();
1094
1095    cx.run_until_parked();
1096
1097    ep_store.update(cx, |ep_store, cx| {
1098        assert_eq!(
1099            ep_store
1100                .prediction_at(&buffer, None, &project, cx)
1101                .unwrap()
1102                .id
1103                .0,
1104            first_id
1105        );
1106    });
1107
1108    // a second request is triggered
1109    ep_store.update(cx, |ep_store, cx| {
1110        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1111    });
1112
1113    let (request, respond_tx) = requests.predict.next().await.unwrap();
1114    // worse than current prediction
1115    let second_response = model_response(
1116        &request,
1117        indoc! { r"
1118            --- a/root/foo.md
1119            +++ b/root/foo.md
1120            @@ ... @@
1121             Hello!
1122            -How
1123            +How are
1124             Bye
1125        "},
1126    );
1127    let second_id = second_response.request_id.clone();
1128    respond_tx.send(second_response).unwrap();
1129
1130    cx.run_until_parked();
1131
1132    ep_store.update(cx, |ep_store, cx| {
1133        // first is preferred over second
1134        assert_eq!(
1135            ep_store
1136                .prediction_at(&buffer, None, &project, cx)
1137                .unwrap()
1138                .id
1139                .0,
1140            first_id
1141        );
1142    });
1143
1144    // second is reported as rejected
1145    let (reject_request, _) = requests.reject.next().await.unwrap();
1146
1147    assert_eq!(
1148        &reject_request.rejections,
1149        &[EditPredictionRejection {
1150            request_id: second_id,
1151            reason: EditPredictionRejectReason::CurrentPreferred,
1152            was_shown: false,
1153            model_version: None,
1154        }]
1155    );
1156}
1157
1158#[gpui::test]
1159async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1160    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1161    let fs = FakeFs::new(cx.executor());
1162    fs.insert_tree(
1163        "/root",
1164        json!({
1165            "foo.md":  "Hello!\nHow\nBye\n"
1166        }),
1167    )
1168    .await;
1169    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1170
1171    let buffer = project
1172        .update(cx, |project, cx| {
1173            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1174            project.open_buffer(path, cx)
1175        })
1176        .await
1177        .unwrap();
1178    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1179    let position = snapshot.anchor_before(language::Point::new(1, 3));
1180
1181    // start two refresh tasks
1182    ep_store.update(cx, |ep_store, cx| {
1183        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1184    });
1185
1186    let (request1, respond_first) = requests.predict.next().await.unwrap();
1187
1188    ep_store.update(cx, |ep_store, cx| {
1189        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1190    });
1191
1192    let (request, respond_second) = requests.predict.next().await.unwrap();
1193
1194    // wait for throttle
1195    cx.run_until_parked();
1196
1197    // second responds first
1198    let second_response = model_response(&request, SIMPLE_DIFF);
1199    let second_id = second_response.request_id.clone();
1200    respond_second.send(second_response).unwrap();
1201
1202    cx.run_until_parked();
1203
1204    ep_store.update(cx, |ep_store, cx| {
1205        // current prediction is second
1206        assert_eq!(
1207            ep_store
1208                .prediction_at(&buffer, None, &project, cx)
1209                .unwrap()
1210                .id
1211                .0,
1212            second_id
1213        );
1214    });
1215
1216    let first_response = model_response(&request1, SIMPLE_DIFF);
1217    let first_id = first_response.request_id.clone();
1218    respond_first.send(first_response).unwrap();
1219
1220    cx.run_until_parked();
1221
1222    ep_store.update(cx, |ep_store, cx| {
1223        // current prediction is still second, since first was cancelled
1224        assert_eq!(
1225            ep_store
1226                .prediction_at(&buffer, None, &project, cx)
1227                .unwrap()
1228                .id
1229                .0,
1230            second_id
1231        );
1232    });
1233
1234    // first is reported as rejected
1235    let (reject_request, _) = requests.reject.next().await.unwrap();
1236
1237    cx.run_until_parked();
1238
1239    assert_eq!(
1240        &reject_request.rejections,
1241        &[EditPredictionRejection {
1242            request_id: first_id,
1243            reason: EditPredictionRejectReason::Canceled,
1244            was_shown: false,
1245            model_version: None,
1246        }]
1247    );
1248}
1249
1250#[gpui::test]
1251async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
1252    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1253    let fs = FakeFs::new(cx.executor());
1254    fs.insert_tree(
1255        "/root",
1256        json!({
1257            "foo.md":  "Hello!\nHow\nBye\n"
1258        }),
1259    )
1260    .await;
1261    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1262
1263    let buffer = project
1264        .update(cx, |project, cx| {
1265            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1266            project.open_buffer(path, cx)
1267        })
1268        .await
1269        .unwrap();
1270    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1271    let position = snapshot.anchor_before(language::Point::new(1, 3));
1272
1273    // start two refresh tasks
1274    ep_store.update(cx, |ep_store, cx| {
1275        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1276    });
1277
1278    let (request1, respond_first) = requests.predict.next().await.unwrap();
1279
1280    ep_store.update(cx, |ep_store, cx| {
1281        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1282    });
1283
1284    let (request2, respond_second) = requests.predict.next().await.unwrap();
1285
1286    // wait for throttle, so requests are sent
1287    cx.run_until_parked();
1288
1289    ep_store.update(cx, |ep_store, cx| {
1290        // start a third request
1291        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1292
1293        // 2 are pending, so 2nd is cancelled
1294        assert_eq!(
1295            ep_store
1296                .get_or_init_project(&project, cx)
1297                .cancelled_predictions
1298                .iter()
1299                .copied()
1300                .collect::<Vec<_>>(),
1301            [1]
1302        );
1303    });
1304
1305    // wait for throttle
1306    cx.run_until_parked();
1307
1308    let (request3, respond_third) = requests.predict.next().await.unwrap();
1309
1310    let first_response = model_response(&request1, SIMPLE_DIFF);
1311    let first_id = first_response.request_id.clone();
1312    respond_first.send(first_response).unwrap();
1313
1314    cx.run_until_parked();
1315
1316    ep_store.update(cx, |ep_store, cx| {
1317        // current prediction is first
1318        assert_eq!(
1319            ep_store
1320                .prediction_at(&buffer, None, &project, cx)
1321                .unwrap()
1322                .id
1323                .0,
1324            first_id
1325        );
1326    });
1327
1328    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
1329    let cancelled_id = cancelled_response.request_id.clone();
1330    respond_second.send(cancelled_response).unwrap();
1331
1332    cx.run_until_parked();
1333
1334    ep_store.update(cx, |ep_store, cx| {
1335        // current prediction is still first, since second was cancelled
1336        assert_eq!(
1337            ep_store
1338                .prediction_at(&buffer, None, &project, cx)
1339                .unwrap()
1340                .id
1341                .0,
1342            first_id
1343        );
1344    });
1345
1346    let third_response = model_response(&request3, SIMPLE_DIFF);
1347    let third_response_id = third_response.request_id.clone();
1348    respond_third.send(third_response).unwrap();
1349
1350    cx.run_until_parked();
1351
1352    ep_store.update(cx, |ep_store, cx| {
1353        // third completes and replaces first
1354        assert_eq!(
1355            ep_store
1356                .prediction_at(&buffer, None, &project, cx)
1357                .unwrap()
1358                .id
1359                .0,
1360            third_response_id
1361        );
1362    });
1363
1364    // second is reported as rejected
1365    let (reject_request, _) = requests.reject.next().await.unwrap();
1366
1367    cx.run_until_parked();
1368
1369    assert_eq!(
1370        &reject_request.rejections,
1371        &[
1372            EditPredictionRejection {
1373                request_id: cancelled_id,
1374                reason: EditPredictionRejectReason::Canceled,
1375                was_shown: false,
1376                model_version: None,
1377            },
1378            EditPredictionRejection {
1379                request_id: first_id,
1380                reason: EditPredictionRejectReason::Replaced,
1381                was_shown: false,
1382                model_version: None,
1383            }
1384        ]
1385    );
1386}
1387
1388#[gpui::test]
1389async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
1390    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1391
1392    let fs = FakeFs::new(cx.executor());
1393    fs.insert_tree(
1394        "/root",
1395        json!({
1396            "foo.md":  "Hello!\nHow\nBye\n",
1397            "bar.md": "Hola!\nComo\nAdios\n"
1398        }),
1399    )
1400    .await;
1401    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1402
1403    let buffer = project
1404        .update(cx, |project, cx| {
1405            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1406            project.set_active_path(Some(path.clone()), cx);
1407            project.open_buffer(path, cx)
1408        })
1409        .await
1410        .unwrap();
1411    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1412    let position = snapshot.anchor_before(language::Point::new(1, 3));
1413
1414    ep_store.update(cx, |ep_store, cx| {
1415        ep_store.register_project(&project, cx);
1416        ep_store.register_buffer(&buffer, &project, cx);
1417    });
1418
1419    // First edit request - no prior edit, so not throttled.
1420    ep_store.update(cx, |ep_store, cx| {
1421        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1422    });
1423    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
1424    edit_response_tx.send(empty_response()).unwrap();
1425    cx.run_until_parked();
1426
1427    let diagnostic = lsp::Diagnostic {
1428        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1429        severity: Some(lsp::DiagnosticSeverity::ERROR),
1430        message: "Sentence is incomplete".to_string(),
1431        ..Default::default()
1432    };
1433
1434    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
1435    project.update(cx, |project, cx| {
1436        project.lsp_store().update(cx, |lsp_store, cx| {
1437            lsp_store
1438                .update_diagnostics(
1439                    LanguageServerId(0),
1440                    lsp::PublishDiagnosticsParams {
1441                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
1442                        diagnostics: vec![diagnostic],
1443                        version: None,
1444                    },
1445                    None,
1446                    language::DiagnosticSourceKind::Pushed,
1447                    &[],
1448                    cx,
1449                )
1450                .unwrap();
1451        });
1452    });
1453    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
1454    jump_response_tx.send(empty_response()).unwrap();
1455    cx.run_until_parked();
1456
1457    // Second edit request - should be throttled by the first edit.
1458    ep_store.update(cx, |ep_store, cx| {
1459        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1460    });
1461    assert_no_predict_request_ready(&mut requests.predict);
1462
1463    // Second jump request - should be throttled by the first jump.
1464    ep_store.update(cx, |ep_store, cx| {
1465        ep_store.refresh_prediction_from_diagnostics(
1466            project.clone(),
1467            DiagnosticSearchScope::Global,
1468            cx,
1469        );
1470    });
1471    assert_no_predict_request_ready(&mut requests.predict);
1472
1473    // Wait for both throttles to expire.
1474    cx.background_executor
1475        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
1476    cx.background_executor.run_until_parked();
1477    cx.run_until_parked();
1478
1479    // Both requests should now go through.
1480    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
1481    response_tx_1.send(empty_response()).unwrap();
1482    cx.run_until_parked();
1483
1484    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
1485    response_tx_2.send(empty_response()).unwrap();
1486    cx.run_until_parked();
1487}
1488
1489#[gpui::test]
1490async fn test_rejections_flushing(cx: &mut TestAppContext) {
1491    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1492
1493    ep_store.update(cx, |ep_store, cx| {
1494        ep_store.reject_prediction(
1495            EditPredictionId("test-1".into()),
1496            EditPredictionRejectReason::Discarded,
1497            false,
1498            None,
1499            cx,
1500        );
1501        ep_store.reject_prediction(
1502            EditPredictionId("test-2".into()),
1503            EditPredictionRejectReason::Canceled,
1504            true,
1505            None,
1506            cx,
1507        );
1508    });
1509
1510    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1511    cx.run_until_parked();
1512
1513    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1514    respond_tx.send(()).unwrap();
1515
1516    // batched
1517    assert_eq!(reject_request.rejections.len(), 2);
1518    assert_eq!(
1519        reject_request.rejections[0],
1520        EditPredictionRejection {
1521            request_id: "test-1".to_string(),
1522            reason: EditPredictionRejectReason::Discarded,
1523            was_shown: false,
1524            model_version: None,
1525        }
1526    );
1527    assert_eq!(
1528        reject_request.rejections[1],
1529        EditPredictionRejection {
1530            request_id: "test-2".to_string(),
1531            reason: EditPredictionRejectReason::Canceled,
1532            was_shown: true,
1533            model_version: None,
1534        }
1535    );
1536
1537    // Reaching batch size limit sends without debounce
1538    ep_store.update(cx, |ep_store, cx| {
1539        for i in 0..70 {
1540            ep_store.reject_prediction(
1541                EditPredictionId(format!("batch-{}", i).into()),
1542                EditPredictionRejectReason::Discarded,
1543                false,
1544                None,
1545                cx,
1546            );
1547        }
1548    });
1549
1550    // First MAX/2 items are sent immediately
1551    cx.run_until_parked();
1552    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1553    respond_tx.send(()).unwrap();
1554
1555    assert_eq!(reject_request.rejections.len(), 50);
1556    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1557    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1558
1559    // Remaining items are debounced with the next batch
1560    cx.executor().advance_clock(Duration::from_secs(15));
1561    cx.run_until_parked();
1562
1563    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1564    respond_tx.send(()).unwrap();
1565
1566    assert_eq!(reject_request.rejections.len(), 20);
1567    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1568    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1569
1570    // Request failure
1571    ep_store.update(cx, |ep_store, cx| {
1572        ep_store.reject_prediction(
1573            EditPredictionId("retry-1".into()),
1574            EditPredictionRejectReason::Discarded,
1575            false,
1576            None,
1577            cx,
1578        );
1579    });
1580
1581    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1582    cx.run_until_parked();
1583
1584    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1585    assert_eq!(reject_request.rejections.len(), 1);
1586    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1587    // Simulate failure
1588    drop(_respond_tx);
1589
1590    // Add another rejection
1591    ep_store.update(cx, |ep_store, cx| {
1592        ep_store.reject_prediction(
1593            EditPredictionId("retry-2".into()),
1594            EditPredictionRejectReason::Discarded,
1595            false,
1596            None,
1597            cx,
1598        );
1599    });
1600
1601    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1602    cx.run_until_parked();
1603
1604    // Retry should include both the failed item and the new one
1605    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1606    respond_tx.send(()).unwrap();
1607
1608    assert_eq!(reject_request.rejections.len(), 2);
1609    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1610    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1611}
1612
1613// Skipped until we start including diagnostics in prompt
1614// #[gpui::test]
1615// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1616//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1617//     let fs = FakeFs::new(cx.executor());
1618//     fs.insert_tree(
1619//         "/root",
1620//         json!({
1621//             "foo.md": "Hello!\nBye"
1622//         }),
1623//     )
1624//     .await;
1625//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1626
1627//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1628//     let diagnostic = lsp::Diagnostic {
1629//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1630//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1631//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1632//         ..Default::default()
1633//     };
1634
1635//     project.update(cx, |project, cx| {
1636//         project.lsp_store().update(cx, |lsp_store, cx| {
1637//             // Create some diagnostics
1638//             lsp_store
1639//                 .update_diagnostics(
1640//                     LanguageServerId(0),
1641//                     lsp::PublishDiagnosticsParams {
1642//                         uri: path_to_buffer_uri.clone(),
1643//                         diagnostics: vec![diagnostic],
1644//                         version: None,
1645//                     },
1646//                     None,
1647//                     language::DiagnosticSourceKind::Pushed,
1648//                     &[],
1649//                     cx,
1650//                 )
1651//                 .unwrap();
1652//         });
1653//     });
1654
1655//     let buffer = project
1656//         .update(cx, |project, cx| {
1657//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1658//             project.open_buffer(path, cx)
1659//         })
1660//         .await
1661//         .unwrap();
1662
1663//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1664//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1665
1666//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1667//         ep_store.request_prediction(&project, &buffer, position, cx)
1668//     });
1669
1670//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1671
1672//     assert_eq!(request.diagnostic_groups.len(), 1);
1673//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1674//         .unwrap();
1675//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1676//     assert_eq!(
1677//         value,
1678//         json!({
1679//             "entries": [{
1680//                 "range": {
1681//                     "start": 8,
1682//                     "end": 10
1683//                 },
1684//                 "diagnostic": {
1685//                     "source": null,
1686//                     "code": null,
1687//                     "code_description": null,
1688//                     "severity": 1,
1689//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1690//                     "markdown": null,
1691//                     "group_id": 0,
1692//                     "is_primary": true,
1693//                     "is_disk_based": false,
1694//                     "is_unnecessary": false,
1695//                     "source_kind": "Pushed",
1696//                     "data": null,
1697//                     "underline": true
1698//                 }
1699//             }],
1700//             "primary_ix": 0
1701//         })
1702//     );
1703// }
1704
1705// Generate a model response that would apply the given diff to the active file.
1706fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1707    let editable_range =
1708        zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1709    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1710    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1711
1712    PredictEditsV3Response {
1713        request_id: Uuid::new_v4().to_string(),
1714        editable_range,
1715        output: new_excerpt,
1716        model_version: None,
1717    }
1718}
1719
1720fn empty_response() -> PredictEditsV3Response {
1721    PredictEditsV3Response {
1722        request_id: Uuid::new_v4().to_string(),
1723        editable_range: 0..0,
1724        output: String::new(),
1725        model_version: None,
1726    }
1727}
1728
1729fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1730    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1731}
1732
1733fn assert_no_predict_request_ready(
1734    requests: &mut mpsc::UnboundedReceiver<(
1735        PredictEditsV3Request,
1736        oneshot::Sender<PredictEditsV3Response>,
1737    )>,
1738) {
1739    if requests.next().now_or_never().flatten().is_some() {
1740        panic!("Unexpected prediction request while throttled.");
1741    }
1742}
1743
1744struct RequestChannels {
1745    predict: mpsc::UnboundedReceiver<(
1746        PredictEditsV3Request,
1747        oneshot::Sender<PredictEditsV3Response>,
1748    )>,
1749    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1750}
1751
1752fn init_test_with_fake_client(
1753    cx: &mut TestAppContext,
1754) -> (Entity<EditPredictionStore>, RequestChannels) {
1755    cx.update(move |cx| {
1756        let settings_store = SettingsStore::test(cx);
1757        cx.set_global(settings_store);
1758        zlog::init_test();
1759
1760        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1761        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1762
1763        let http_client = FakeHttpClient::create({
1764            move |req| {
1765                let uri = req.uri().path().to_string();
1766                let mut body = req.into_body();
1767                let predict_req_tx = predict_req_tx.clone();
1768                let reject_req_tx = reject_req_tx.clone();
1769                async move {
1770                    let resp = match uri.as_str() {
1771                        "/client/llm_tokens" => serde_json::to_string(&json!({
1772                            "token": "test"
1773                        }))
1774                        .unwrap(),
1775                        "/predict_edits/v3" => {
1776                            let mut buf = Vec::new();
1777                            body.read_to_end(&mut buf).await.ok();
1778                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
1779                            let req = serde_json::from_slice(&decompressed).unwrap();
1780
1781                            let (res_tx, res_rx) = oneshot::channel();
1782                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1783                            serde_json::to_string(&res_rx.await?).unwrap()
1784                        }
1785                        "/predict_edits/reject" => {
1786                            let mut buf = Vec::new();
1787                            body.read_to_end(&mut buf).await.ok();
1788                            let req = serde_json::from_slice(&buf).unwrap();
1789
1790                            let (res_tx, res_rx) = oneshot::channel();
1791                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1792                            serde_json::to_string(&res_rx.await?).unwrap()
1793                        }
1794                        _ => {
1795                            panic!("Unexpected path: {}", uri)
1796                        }
1797                    };
1798
1799                    Ok(Response::builder().body(resp.into()).unwrap())
1800                }
1801            }
1802        });
1803
1804        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1805        client.cloud_client().set_credentials(1, "test".into());
1806
1807        language_model::init(client.clone(), cx);
1808
1809        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1810        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1811
1812        (
1813            ep_store,
1814            RequestChannels {
1815                predict: predict_req_rx,
1816                reject: reject_req_rx,
1817            },
1818        )
1819    })
1820}
1821
1822#[gpui::test]
1823async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1824    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1825    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1826        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1827    });
1828
1829    let edit_preview = cx
1830        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1831        .await;
1832
1833    let prediction = EditPrediction {
1834        edits,
1835        cursor_position: None,
1836        edit_preview,
1837        buffer: buffer.clone(),
1838        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1839        id: EditPredictionId("the-id".into()),
1840        inputs: ZetaPromptInput {
1841            events: Default::default(),
1842            related_files: Default::default(),
1843            cursor_path: Path::new("").into(),
1844            cursor_excerpt: "".into(),
1845            cursor_offset_in_excerpt: 0,
1846            excerpt_start_row: None,
1847            excerpt_ranges: Default::default(),
1848            experiment: None,
1849            in_open_source_repo: false,
1850            can_collect_data: false,
1851        },
1852        buffer_snapshotted_at: Instant::now(),
1853        response_received_at: Instant::now(),
1854        model_version: None,
1855    };
1856
1857    cx.update(|cx| {
1858        assert_eq!(
1859            from_completion_edits(
1860                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1861                &buffer,
1862                cx
1863            ),
1864            vec![(2..5, "REM".into()), (9..11, "".into())]
1865        );
1866
1867        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1868        assert_eq!(
1869            from_completion_edits(
1870                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1871                &buffer,
1872                cx
1873            ),
1874            vec![(2..2, "REM".into()), (6..8, "".into())]
1875        );
1876
1877        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1878        assert_eq!(
1879            from_completion_edits(
1880                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1881                &buffer,
1882                cx
1883            ),
1884            vec![(2..5, "REM".into()), (9..11, "".into())]
1885        );
1886
1887        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1888        assert_eq!(
1889            from_completion_edits(
1890                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1891                &buffer,
1892                cx
1893            ),
1894            vec![(3..3, "EM".into()), (7..9, "".into())]
1895        );
1896
1897        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1898        assert_eq!(
1899            from_completion_edits(
1900                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1901                &buffer,
1902                cx
1903            ),
1904            vec![(4..4, "M".into()), (8..10, "".into())]
1905        );
1906
1907        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1908        assert_eq!(
1909            from_completion_edits(
1910                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1911                &buffer,
1912                cx
1913            ),
1914            vec![(9..11, "".into())]
1915        );
1916
1917        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1918        assert_eq!(
1919            from_completion_edits(
1920                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1921                &buffer,
1922                cx
1923            ),
1924            vec![(4..4, "M".into()), (8..10, "".into())]
1925        );
1926
1927        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1928        assert_eq!(
1929            from_completion_edits(
1930                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1931                &buffer,
1932                cx
1933            ),
1934            vec![(4..4, "M".into())]
1935        );
1936
1937        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1938        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1939    })
1940}
1941
1942#[gpui::test]
1943async fn test_clean_up_diff(cx: &mut TestAppContext) {
1944    init_test(cx);
1945
1946    assert_eq!(
1947        apply_edit_prediction(
1948            indoc! {"
1949                    fn main() {
1950                        let word_1 = \"lorem\";
1951                        let range = word.len()..word.len();
1952                    }
1953                "},
1954            indoc! {"
1955                    fn main() {
1956                        let word_1 = \"lorem\";
1957                        let range = word_1.len()..word_1.len();
1958                    }
1959                "},
1960            cx,
1961        )
1962        .await,
1963        indoc! {"
1964                fn main() {
1965                    let word_1 = \"lorem\";
1966                    let range = word_1.len()..word_1.len();
1967                }
1968            "},
1969    );
1970
1971    assert_eq!(
1972        apply_edit_prediction(
1973            indoc! {"
1974                    fn main() {
1975                        let story = \"the quick\"
1976                    }
1977                "},
1978            indoc! {"
1979                    fn main() {
1980                        let story = \"the quick brown fox jumps over the lazy dog\";
1981                    }
1982                "},
1983            cx,
1984        )
1985        .await,
1986        indoc! {"
1987                fn main() {
1988                    let story = \"the quick brown fox jumps over the lazy dog\";
1989                }
1990            "},
1991    );
1992}
1993
1994#[gpui::test]
1995async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1996    init_test(cx);
1997
1998    let buffer_content = "lorem\n";
1999    let completion_response = "lorem\nipsum\n";
2000
2001    assert_eq!(
2002        apply_edit_prediction(buffer_content, completion_response, cx).await,
2003        "lorem\nipsum\n"
2004    );
2005}
2006
2007#[gpui::test]
2008async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
2009    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
2010    // When the buffer ends without a trailing newline, but the model returns output
2011    // with a trailing newline, zeta2 should normalize both sides before diffing
2012    // so no spurious newline is inserted.
2013    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2014    let fs = FakeFs::new(cx.executor());
2015
2016    // Single line buffer with no trailing newline
2017    fs.insert_tree(
2018        "/root",
2019        json!({
2020            "foo.txt": "hello"
2021        }),
2022    )
2023    .await;
2024    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2025
2026    let buffer = project
2027        .update(cx, |project, cx| {
2028            let path = project
2029                .find_project_path(path!("root/foo.txt"), cx)
2030                .unwrap();
2031            project.open_buffer(path, cx)
2032        })
2033        .await
2034        .unwrap();
2035
2036    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2037    let position = snapshot.anchor_before(language::Point::new(0, 5));
2038
2039    ep_store.update(cx, |ep_store, cx| {
2040        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2041    });
2042
2043    let (request, respond_tx) = requests.predict.next().await.unwrap();
2044
2045    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2046    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2047    let excerpt_length = request.input.cursor_excerpt.len();
2048    let response = PredictEditsV3Response {
2049        request_id: Uuid::new_v4().to_string(),
2050        output: "hello world\n".to_string(),
2051        editable_range: 0..excerpt_length,
2052        model_version: None,
2053    };
2054    respond_tx.send(response).unwrap();
2055
2056    cx.run_until_parked();
2057
2058    // The prediction should insert " world" without adding a newline
2059    ep_store.update(cx, |ep_store, cx| {
2060        let prediction = ep_store
2061            .prediction_at(&buffer, None, &project, cx)
2062            .expect("should have prediction");
2063        let edits: Vec<_> = prediction
2064            .edits
2065            .iter()
2066            .map(|(range, text)| {
2067                let snapshot = buffer.read(cx).snapshot();
2068                (range.to_offset(&snapshot), text.clone())
2069            })
2070            .collect();
2071        assert_eq!(edits, vec![(5..5, " world".into())]);
2072    });
2073}
2074
2075fn init_test(cx: &mut TestAppContext) {
2076    cx.update(|cx| {
2077        let settings_store = SettingsStore::test(cx);
2078        cx.set_global(settings_store);
2079    });
2080}
2081
2082async fn apply_edit_prediction(
2083    buffer_content: &str,
2084    completion_response: &str,
2085    cx: &mut TestAppContext,
2086) -> String {
2087    let fs = project::FakeFs::new(cx.executor());
2088    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2089    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2090    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2091    *response.lock() = completion_response.to_string();
2092    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2093    buffer.update(cx, |buffer, cx| {
2094        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2095    });
2096    buffer.read_with(cx, |buffer, _| buffer.text())
2097}
2098
2099async fn run_edit_prediction(
2100    buffer: &Entity<Buffer>,
2101    project: &Entity<Project>,
2102    ep_store: &Entity<EditPredictionStore>,
2103    cx: &mut TestAppContext,
2104) -> EditPrediction {
2105    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2106    ep_store.update(cx, |ep_store, cx| {
2107        ep_store.register_buffer(buffer, &project, cx)
2108    });
2109    cx.background_executor.run_until_parked();
2110    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2111        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2112    });
2113    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2114}
2115
2116async fn make_test_ep_store(
2117    project: &Entity<Project>,
2118    cx: &mut TestAppContext,
2119) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2120    let default_response = "hello world\n".to_string();
2121    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2122    let http_client = FakeHttpClient::create({
2123        let completion_response = completion_response.clone();
2124        let mut next_request_id = 0;
2125        move |req| {
2126            let completion_response = completion_response.clone();
2127            let method = req.method().clone();
2128            let uri = req.uri().path().to_string();
2129            let mut body = req.into_body();
2130            async move {
2131                match (method, uri.as_str()) {
2132                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2133                        .status(200)
2134                        .body(
2135                            serde_json::to_string(&CreateLlmTokenResponse {
2136                                token: LlmToken("the-llm-token".to_string()),
2137                            })
2138                            .unwrap()
2139                            .into(),
2140                        )
2141                        .unwrap()),
2142                    (Method::POST, "/predict_edits/v3") => {
2143                        let mut buf = Vec::new();
2144                        body.read_to_end(&mut buf).await.ok();
2145                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
2146                        let req: PredictEditsV3Request =
2147                            serde_json::from_slice(&decompressed).unwrap();
2148
2149                        next_request_id += 1;
2150                        Ok(http_client::Response::builder()
2151                            .status(200)
2152                            .body(
2153                                serde_json::to_string(&PredictEditsV3Response {
2154                                    request_id: format!("request-{next_request_id}"),
2155                                    editable_range: 0..req.input.cursor_excerpt.len(),
2156                                    output: completion_response.lock().clone(),
2157                                    model_version: None,
2158                                })
2159                                .unwrap()
2160                                .into(),
2161                            )
2162                            .unwrap())
2163                    }
2164                    _ => Ok(http_client::Response::builder()
2165                        .status(404)
2166                        .body("Not Found".to_string().into())
2167                        .unwrap()),
2168                }
2169            }
2170        }
2171    });
2172
2173    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2174    cx.update(|cx| {
2175        RefreshLlmTokenListener::register(client.clone(), cx);
2176    });
2177    let _server = FakeServer::for_client(42, &client, cx).await;
2178
2179    let ep_store = cx.new(|cx| {
2180        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2181        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2182
2183        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2184        for worktree in worktrees {
2185            let worktree_id = worktree.read(cx).id();
2186            ep_store
2187                .get_or_init_project(project, cx)
2188                .license_detection_watchers
2189                .entry(worktree_id)
2190                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2191        }
2192
2193        ep_store
2194    });
2195
2196    (ep_store, completion_response)
2197}
2198
2199fn to_completion_edits(
2200    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2201    buffer: &Entity<Buffer>,
2202    cx: &App,
2203) -> Vec<(Range<Anchor>, Arc<str>)> {
2204    let buffer = buffer.read(cx);
2205    iterator
2206        .into_iter()
2207        .map(|(range, text)| {
2208            (
2209                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2210                text,
2211            )
2212        })
2213        .collect()
2214}
2215
2216fn from_completion_edits(
2217    editor_edits: &[(Range<Anchor>, Arc<str>)],
2218    buffer: &Entity<Buffer>,
2219    cx: &App,
2220) -> Vec<(Range<usize>, Arc<str>)> {
2221    let buffer = buffer.read(cx);
2222    editor_edits
2223        .iter()
2224        .map(|(range, text)| {
2225            (
2226                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2227                text.clone(),
2228            )
2229        })
2230        .collect()
2231}
2232
2233#[gpui::test]
2234async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2235    init_test(cx);
2236
2237    let fs = FakeFs::new(cx.executor());
2238    fs.insert_tree(
2239        "/project",
2240        serde_json::json!({
2241            "main.rs": "fn main() {\n    \n}\n"
2242        }),
2243    )
2244    .await;
2245
2246    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2247
2248    let http_client = FakeHttpClient::create(|_req| async move {
2249        Ok(gpui::http_client::Response::builder()
2250            .status(401)
2251            .body("Unauthorized".into())
2252            .unwrap())
2253    });
2254
2255    let client =
2256        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2257    cx.update(|cx| {
2258        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2259    });
2260
2261    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2262
2263    let buffer = project
2264        .update(cx, |project, cx| {
2265            let path = project
2266                .find_project_path(path!("/project/main.rs"), cx)
2267                .unwrap();
2268            project.open_buffer(path, cx)
2269        })
2270        .await
2271        .unwrap();
2272
2273    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2274    ep_store.update(cx, |ep_store, cx| {
2275        ep_store.register_buffer(&buffer, &project, cx)
2276    });
2277    cx.background_executor.run_until_parked();
2278
2279    let completion_task = ep_store.update(cx, |ep_store, cx| {
2280        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2281        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2282    });
2283
2284    let result = completion_task.await;
2285    assert!(
2286        result.is_err(),
2287        "Without authentication and without custom URL, prediction should fail"
2288    );
2289}
2290
2291#[gpui::test]
2292fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2293    let buffer = cx.new(|cx| {
2294        Buffer::local(
2295            indoc! {"
2296                zero
2297                one
2298                two
2299                three
2300                four
2301                five
2302                six
2303                seven
2304                eight
2305                nine
2306                ten
2307                eleven
2308                twelve
2309                thirteen
2310                fourteen
2311                fifteen
2312                sixteen
2313                seventeen
2314                eighteen
2315                nineteen
2316                twenty
2317                twenty-one
2318                twenty-two
2319                twenty-three
2320                twenty-four
2321            "},
2322            cx,
2323        )
2324    });
2325
2326    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2327
2328    buffer.update(cx, |buffer, cx| {
2329        let point = Point::new(12, 0);
2330        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2331        let point = Point::new(8, 0);
2332        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2333    });
2334
2335    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2336
2337    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2338
2339    assert_eq!(
2340        diff,
2341        indoc! {"
2342            @@ -6,10 +6,12 @@
2343             five
2344             six
2345             seven
2346            +FIRST INSERTION
2347             eight
2348             nine
2349             ten
2350             eleven
2351            +SECOND INSERTION
2352             twelve
2353             thirteen
2354             fourteen
2355            "}
2356    );
2357}
2358
/// Checks that `EditPredictionStore::next_diagnostic_location` avoids diagnostics
/// near a collaborator's cursor. Those diagnostics are only returned when the local
/// cursor is also nearby. Other files holding a collaborator cursor are skipped when
/// searching across files.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    /// Simulates a remote collaborator (replica 10) by applying an
    /// `UpdateSelections` operation that places a single caret at column 0 of `row`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        // Empty selection (start == end), i.e. a caret.
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    /// Pushes one ERROR diagnostic per entry of `rows` (columns 0..5 of that row)
    /// for the file at `uri_path`, attributed to `LanguageServerId(0)`.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60-line active file: "line 0\n" .. "line 59\n".
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator sits at row 5; diagnostics at rows 3 (near the collaborator),
    // 25 and 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator. Row 3 must be
    // excluded.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, also near the collaborator. Row 3 becomes
    // eligible again and is expected to be the result.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50, far from the collaborator. Row 50 is expected.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4 (cross-file): both other files get a diagnostic on row 0, but only
    // collab_file.txt has a collaborator cursor on it.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // NOTE(review): presumably diagnostics inside the search range are not returned
    // for the active file, so covering rows 0..59 forces a cross-file result. Confirm
    // against `next_diagnostic_location`.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2574
2575#[gpui::test]
2576async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
2577    let (ep_store, _requests) = init_test_with_fake_client(cx);
2578    let fs = FakeFs::new(cx.executor());
2579
2580    // Buffer with two clearly separated regions:
2581    //   Region A = lines 0-9   (offsets 0..50)
2582    //   Region B = lines 20-29 (offsets 105..155)
2583    // A big gap in between so edits in one region never overlap the other.
2584    let mut content = String::new();
2585    for i in 0..30 {
2586        content.push_str(&format!("line {i:02}\n"));
2587    }
2588
2589    fs.insert_tree(
2590        "/root",
2591        json!({
2592            "foo.md": content.clone()
2593        }),
2594    )
2595    .await;
2596    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2597
2598    let buffer = project
2599        .update(cx, |project, cx| {
2600            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2601            project.open_buffer(path, cx)
2602        })
2603        .await
2604        .unwrap();
2605
2606    let settled_events: Arc<Mutex<Vec<(EditPredictionId, String)>>> =
2607        Arc::new(Mutex::new(Vec::new()));
2608
2609    ep_store.update(cx, |ep_store, cx| {
2610        ep_store.register_buffer(&buffer, &project, cx);
2611
2612        let settled_events = settled_events.clone();
2613        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
2614            settled_events.lock().push((id, text));
2615        }));
2616    });
2617
2618    // --- Phase 1: edit in region A and enqueue prediction A ---
2619
2620    buffer.update(cx, |buffer, cx| {
2621        // Edit at the start of line 0.
2622        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
2623    });
2624    cx.run_until_parked();
2625
2626    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2627
2628    // Region A: first 10 lines of the buffer.
2629    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
2630    ep_store.update(cx, |ep_store, cx| {
2631        ep_store.enqueue_settled_prediction(
2632            EditPredictionId("prediction-a".into()),
2633            &project,
2634            &buffer,
2635            &snapshot_a,
2636            editable_region_a,
2637            cx,
2638        );
2639    });
2640
2641    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---
2642
2643    // Let the worker process the channel message before we start advancing.
2644    cx.run_until_parked();
2645
2646    let mut region_a_edit_offset = 5;
2647    for _ in 0..3 {
2648        // Edit inside region A (not at the boundary) so `last_edit_at` is
2649        // updated before the worker's next wake.
2650        buffer.update(cx, |buffer, cx| {
2651            buffer.edit(
2652                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
2653                None,
2654                cx,
2655            );
2656        });
2657        region_a_edit_offset += 1;
2658        cx.run_until_parked();
2659
2660        cx.executor()
2661            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
2662        cx.run_until_parked();
2663        assert!(
2664            settled_events.lock().is_empty(),
2665            "no settled events should fire while region A is still being edited"
2666        );
2667    }
2668
2669    // Still nothing settled.
2670    assert!(settled_events.lock().is_empty());
2671
2672    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
2673    // Advance a small amount so B's quiescence window starts later than A's,
2674    // but not so much that A settles (A's last edit was at the start of
2675    // iteration 3, and it needs a full Q to settle).
2676    cx.executor()
2677        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2678    cx.run_until_parked();
2679    assert!(settled_events.lock().is_empty());
2680
2681    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2682    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));
2683
2684    buffer.update(cx, |buffer, cx| {
2685        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
2686    });
2687    cx.run_until_parked();
2688
2689    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2690    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
2691    ep_store.update(cx, |ep_store, cx| {
2692        ep_store.enqueue_settled_prediction(
2693            EditPredictionId("prediction-b".into()),
2694            &project,
2695            &buffer,
2696            &snapshot_b2,
2697            editable_region_b,
2698            cx,
2699        );
2700    });
2701
2702    cx.run_until_parked();
2703    assert!(
2704        settled_events.lock().is_empty(),
2705        "neither prediction should have settled yet"
2706    );
2707
2708    // --- Phase 4: let enough time pass for region A to settle ---
2709    // A's last edit was at T_a (during the last loop iteration). The worker is
2710    // sleeping until T_a + Q. We advance just enough to reach that wake time
2711    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
2712    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
2713    // enqueued only Q/4 ago and stays pending.
2714    cx.executor()
2715        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2716    cx.run_until_parked();
2717
2718    {
2719        let events = settled_events.lock().clone();
2720        assert_eq!(
2721            events.len(),
2722            1,
2723            "only prediction A should have settled, got: {events:?}"
2724        );
2725        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
2726    }
2727
2728    // --- Phase 5: let more time pass for region B to settle ---
2729    // B's last edit was Q/4 before A settled. The worker rescheduled to
2730    // B's last_edit_at + Q, which is 3Q/4 from now.
2731    cx.executor()
2732        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
2733    cx.run_until_parked();
2734
2735    {
2736        let events = settled_events.lock().clone();
2737        assert_eq!(
2738            events.len(),
2739            2,
2740            "both predictions should have settled, got: {events:?}"
2741        );
2742        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
2743    }
2744}
2745
/// Sets up logging for this test binary.
///
/// The `#[ctor::ctor]` attribute makes this function run once, automatically,
/// when the test binary is loaded — before the test harness starts running any
/// of the `#[gpui::test]` functions in this file — so no individual test needs
/// to call it.
// NOTE(review): the exact behavior (log level filtering, output destination)
// is whatever `zlog::init_test` does; judging by the name it is the
// test-oriented initializer — confirm in the `zlog` crate.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}