edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
   3use client::{UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
   8    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
   9};
  10use futures::{
  11    AsyncReadExt, FutureExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::App;
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::path;
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{
  33    BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
  34    EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
  35};
  36
  37#[gpui::test]
  38async fn test_current_state(cx: &mut TestAppContext) {
  39    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  40    let fs = FakeFs::new(cx.executor());
  41    fs.insert_tree(
  42        "/root",
  43        json!({
  44            "1.txt": "Hello!\nHow\nBye\n",
  45            "2.txt": "Hola!\nComo\nAdios\n"
  46        }),
  47    )
  48    .await;
  49    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  50
  51    let buffer1 = project
  52        .update(cx, |project, cx| {
  53            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  54            project.set_active_path(Some(path.clone()), cx);
  55            project.open_buffer(path, cx)
  56        })
  57        .await
  58        .unwrap();
  59    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  60    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  61
  62    ep_store.update(cx, |ep_store, cx| {
  63        ep_store.register_project(&project, cx);
  64        ep_store.register_buffer(&buffer1, &project, cx);
  65    });
  66
  67    // Prediction for current file
  68
  69    ep_store.update(cx, |ep_store, cx| {
  70        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  71    });
  72    let (request, respond_tx) = requests.predict.next().await.unwrap();
  73
  74    respond_tx
  75        .send(model_response(
  76            &request,
  77            indoc! {r"
  78                --- a/root/1.txt
  79                +++ b/root/1.txt
  80                @@ ... @@
  81                 Hello!
  82                -How
  83                +How are you?
  84                 Bye
  85            "},
  86        ))
  87        .unwrap();
  88
  89    cx.run_until_parked();
  90
  91    ep_store.update(cx, |ep_store, cx| {
  92        let prediction = ep_store
  93            .prediction_at(&buffer1, None, &project, cx)
  94            .unwrap();
  95        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  96    });
  97
  98    ep_store.update(cx, |ep_store, cx| {
  99        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
 100    });
 101
 102    // Prediction for diagnostic in another file
 103
 104    let diagnostic = lsp::Diagnostic {
 105        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 106        severity: Some(lsp::DiagnosticSeverity::ERROR),
 107        message: "Sentence is incomplete".to_string(),
 108        ..Default::default()
 109    };
 110
 111    project.update(cx, |project, cx| {
 112        project.lsp_store().update(cx, |lsp_store, cx| {
 113            lsp_store
 114                .update_diagnostics(
 115                    LanguageServerId(0),
 116                    lsp::PublishDiagnosticsParams {
 117                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 118                        diagnostics: vec![diagnostic],
 119                        version: None,
 120                    },
 121                    None,
 122                    language::DiagnosticSourceKind::Pushed,
 123                    &[],
 124                    cx,
 125                )
 126                .unwrap();
 127        });
 128    });
 129
 130    let (request, respond_tx) = requests.predict.next().await.unwrap();
 131    respond_tx
 132        .send(model_response(
 133            &request,
 134            indoc! {r#"
 135                --- a/root/2.txt
 136                +++ b/root/2.txt
 137                @@ ... @@
 138                 Hola!
 139                -Como
 140                +Como estas?
 141                 Adios
 142            "#},
 143        ))
 144        .unwrap();
 145    cx.run_until_parked();
 146
 147    ep_store.update(cx, |ep_store, cx| {
 148        let prediction = ep_store
 149            .prediction_at(&buffer1, None, &project, cx)
 150            .unwrap();
 151        assert_matches!(
 152            prediction,
 153            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 154        );
 155    });
 156
 157    let buffer2 = project
 158        .update(cx, |project, cx| {
 159            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 160            project.open_buffer(path, cx)
 161        })
 162        .await
 163        .unwrap();
 164
 165    ep_store.update(cx, |ep_store, cx| {
 166        let prediction = ep_store
 167            .prediction_at(&buffer2, None, &project, cx)
 168            .unwrap();
 169        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 170    });
 171}
 172
 173#[gpui::test]
 174async fn test_simple_request(cx: &mut TestAppContext) {
 175    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 176    let fs = FakeFs::new(cx.executor());
 177    fs.insert_tree(
 178        "/root",
 179        json!({
 180            "foo.md":  "Hello!\nHow\nBye\n"
 181        }),
 182    )
 183    .await;
 184    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 185
 186    let buffer = project
 187        .update(cx, |project, cx| {
 188            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 189            project.open_buffer(path, cx)
 190        })
 191        .await
 192        .unwrap();
 193    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 194    let position = snapshot.anchor_before(language::Point::new(1, 3));
 195
 196    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 197        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 198    });
 199
 200    let (request, respond_tx) = requests.predict.next().await.unwrap();
 201
 202    // TODO Put back when we have a structured request again
 203    // assert_eq!(
 204    //     request.excerpt_path.as_ref(),
 205    //     Path::new(path!("root/foo.md"))
 206    // );
 207    // assert_eq!(
 208    //     request.cursor_point,
 209    //     Point {
 210    //         line: Line(1),
 211    //         column: 3
 212    //     }
 213    // );
 214
 215    respond_tx
 216        .send(model_response(
 217            &request,
 218            indoc! { r"
 219                --- a/root/foo.md
 220                +++ b/root/foo.md
 221                @@ ... @@
 222                 Hello!
 223                -How
 224                +How are you?
 225                 Bye
 226            "},
 227        ))
 228        .unwrap();
 229
 230    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 231
 232    assert_eq!(prediction.edits.len(), 1);
 233    assert_eq!(
 234        prediction.edits[0].0.to_point(&snapshot).start,
 235        language::Point::new(1, 3)
 236    );
 237    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 238}
 239
 240#[gpui::test]
 241async fn test_request_events(cx: &mut TestAppContext) {
 242    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 243    let fs = FakeFs::new(cx.executor());
 244    fs.insert_tree(
 245        "/root",
 246        json!({
 247            "foo.md": "Hello!\n\nBye\n"
 248        }),
 249    )
 250    .await;
 251    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 252
 253    let buffer = project
 254        .update(cx, |project, cx| {
 255            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 256            project.open_buffer(path, cx)
 257        })
 258        .await
 259        .unwrap();
 260
 261    ep_store.update(cx, |ep_store, cx| {
 262        ep_store.register_buffer(&buffer, &project, cx);
 263    });
 264
 265    buffer.update(cx, |buffer, cx| {
 266        buffer.edit(vec![(7..7, "How")], None, cx);
 267    });
 268
 269    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 270    let position = snapshot.anchor_before(language::Point::new(1, 3));
 271
 272    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 273        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 274    });
 275
 276    let (request, respond_tx) = requests.predict.next().await.unwrap();
 277
 278    let prompt = prompt_from_request(&request);
 279    assert!(
 280        prompt.contains(indoc! {"
 281        --- a/root/foo.md
 282        +++ b/root/foo.md
 283        @@ -1,3 +1,3 @@
 284         Hello!
 285        -
 286        +How
 287         Bye
 288    "}),
 289        "{prompt}"
 290    );
 291
 292    respond_tx
 293        .send(model_response(
 294            &request,
 295            indoc! {r#"
 296                --- a/root/foo.md
 297                +++ b/root/foo.md
 298                @@ ... @@
 299                 Hello!
 300                -How
 301                +How are you?
 302                 Bye
 303        "#},
 304        ))
 305        .unwrap();
 306
 307    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 308
 309    assert_eq!(prediction.edits.len(), 1);
 310    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 311}
 312
 313#[gpui::test]
 314async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 315    let (ep_store, _requests) = init_test_with_fake_client(cx);
 316    let fs = FakeFs::new(cx.executor());
 317    fs.insert_tree(
 318        "/root",
 319        json!({
 320            "foo.md": "Hello!\n\nBye\n"
 321        }),
 322    )
 323    .await;
 324    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 325
 326    let buffer = project
 327        .update(cx, |project, cx| {
 328            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 329            project.open_buffer(path, cx)
 330        })
 331        .await
 332        .unwrap();
 333
 334    ep_store.update(cx, |ep_store, cx| {
 335        ep_store.register_buffer(&buffer, &project, cx);
 336    });
 337
 338    // First burst: insert "How"
 339    buffer.update(cx, |buffer, cx| {
 340        buffer.edit(vec![(7..7, "How")], None, cx);
 341    });
 342
 343    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 344    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 345    cx.run_until_parked();
 346
 347    // Second burst: append " are you?" immediately after "How" on the same line.
 348    //
 349    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 350    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 351    buffer.update(cx, |buffer, cx| {
 352        buffer.edit(vec![(10..10, " are you?")], None, cx);
 353    });
 354
 355    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 356    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 357    buffer.update(cx, |buffer, cx| {
 358        buffer.edit(vec![(19..19, "!")], None, cx);
 359    });
 360
 361    // With time-based splitting, there are two distinct events.
 362    let events = ep_store.update(cx, |ep_store, cx| {
 363        ep_store.edit_history_for_project(&project, cx)
 364    });
 365    assert_eq!(events.len(), 2);
 366    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 367    assert_eq!(
 368        diff.as_str(),
 369        indoc! {"
 370            @@ -1,3 +1,3 @@
 371             Hello!
 372            -
 373            +How
 374             Bye
 375        "}
 376    );
 377
 378    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 379    assert_eq!(
 380        diff.as_str(),
 381        indoc! {"
 382            @@ -1,3 +1,3 @@
 383             Hello!
 384            -How
 385            +How are you?!
 386             Bye
 387        "}
 388    );
 389}
 390
 391#[gpui::test]
 392async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
 393    let (ep_store, _requests) = init_test_with_fake_client(cx);
 394    let fs = FakeFs::new(cx.executor());
 395
 396    // Create a file with 30 lines to test line-based coalescing
 397    let content = (1..=30)
 398        .map(|i| format!("Line {}\n", i))
 399        .collect::<String>();
 400    fs.insert_tree(
 401        "/root",
 402        json!({
 403            "foo.md": content
 404        }),
 405    )
 406    .await;
 407    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 408
 409    let buffer = project
 410        .update(cx, |project, cx| {
 411            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 412            project.open_buffer(path, cx)
 413        })
 414        .await
 415        .unwrap();
 416
 417    ep_store.update(cx, |ep_store, cx| {
 418        ep_store.register_buffer(&buffer, &project, cx);
 419    });
 420
 421    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
 422    buffer.update(cx, |buffer, cx| {
 423        let start = Point::new(10, 0).to_offset(buffer);
 424        let end = Point::new(13, 0).to_offset(buffer);
 425        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
 426    });
 427
 428    let events = ep_store.update(cx, |ep_store, cx| {
 429        ep_store.edit_history_for_project(&project, cx)
 430    });
 431    assert_eq!(
 432        render_events(&events),
 433        indoc! {"
 434            @@ -8,9 +8,8 @@
 435             Line 8
 436             Line 9
 437             Line 10
 438            -Line 11
 439            -Line 12
 440            -Line 13
 441            +Middle A
 442            +Middle B
 443             Line 14
 444             Line 15
 445             Line 16
 446        "},
 447        "After first edit"
 448    );
 449
 450    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
 451    // This tests that coalescing considers the START of the existing range
 452    buffer.update(cx, |buffer, cx| {
 453        let offset = Point::new(5, 0).to_offset(buffer);
 454        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
 455    });
 456
 457    let events = ep_store.update(cx, |ep_store, cx| {
 458        ep_store.edit_history_for_project(&project, cx)
 459    });
 460    assert_eq!(
 461        render_events(&events),
 462        indoc! {"
 463            @@ -3,14 +3,14 @@
 464             Line 3
 465             Line 4
 466             Line 5
 467            +Above
 468             Line 6
 469             Line 7
 470             Line 8
 471             Line 9
 472             Line 10
 473            -Line 11
 474            -Line 12
 475            -Line 13
 476            +Middle A
 477            +Middle B
 478             Line 14
 479             Line 15
 480             Line 16
 481        "},
 482        "After inserting above (should coalesce)"
 483    );
 484
 485    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
 486    // This tests that coalescing considers the END of the existing range
 487    buffer.update(cx, |buffer, cx| {
 488        let offset = Point::new(14, 0).to_offset(buffer);
 489        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
 490    });
 491
 492    let events = ep_store.update(cx, |ep_store, cx| {
 493        ep_store.edit_history_for_project(&project, cx)
 494    });
 495    assert_eq!(
 496        render_events(&events),
 497        indoc! {"
 498            @@ -3,15 +3,16 @@
 499             Line 3
 500             Line 4
 501             Line 5
 502            +Above
 503             Line 6
 504             Line 7
 505             Line 8
 506             Line 9
 507             Line 10
 508            -Line 11
 509            -Line 12
 510            -Line 13
 511            +Middle A
 512            +Middle B
 513             Line 14
 514            +Below
 515             Line 15
 516             Line 16
 517             Line 17
 518        "},
 519        "After inserting below (should coalesce)"
 520    );
 521
 522    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
 523    // This should NOT coalesce - creates a new event
 524    buffer.update(cx, |buffer, cx| {
 525        let offset = Point::new(25, 0).to_offset(buffer);
 526        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
 527    });
 528
 529    let events = ep_store.update(cx, |ep_store, cx| {
 530        ep_store.edit_history_for_project(&project, cx)
 531    });
 532    assert_eq!(
 533        render_events(&events),
 534        indoc! {"
 535            @@ -3,15 +3,16 @@
 536             Line 3
 537             Line 4
 538             Line 5
 539            +Above
 540             Line 6
 541             Line 7
 542             Line 8
 543             Line 9
 544             Line 10
 545            -Line 11
 546            -Line 12
 547            -Line 13
 548            +Middle A
 549            +Middle B
 550             Line 14
 551            +Below
 552             Line 15
 553             Line 16
 554             Line 17
 555
 556            ---
 557            @@ -23,6 +23,7 @@
 558             Line 22
 559             Line 23
 560             Line 24
 561            +Far below
 562             Line 25
 563             Line 26
 564             Line 27
 565        "},
 566        "After inserting far below (should NOT coalesce)"
 567    );
 568}
 569
 570fn render_events(events: &[StoredEvent]) -> String {
 571    events
 572        .iter()
 573        .map(|e| {
 574            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 575            diff.as_str()
 576        })
 577        .collect::<Vec<_>>()
 578        .join("\n---\n")
 579}
 580
 581fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 582    events
 583        .iter()
 584        .map(|e| {
 585            let zeta_prompt::Event::BufferChange {
 586                diff, predicted, ..
 587            } = e.event.as_ref();
 588            let prefix = if *predicted { "predicted" } else { "manual" };
 589            format!("{}\n{}", prefix, diff)
 590        })
 591        .collect()
 592}
 593
 594#[gpui::test]
 595async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
 596    let (ep_store, _requests) = init_test_with_fake_client(cx);
 597    let fs = FakeFs::new(cx.executor());
 598    fs.insert_tree(
 599        "/root",
 600        json!({
 601            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
 602        }),
 603    )
 604    .await;
 605    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 606
 607    let buffer = project
 608        .update(cx, |project, cx| {
 609            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
 610            project.open_buffer(path, cx)
 611        })
 612        .await
 613        .unwrap();
 614
 615    ep_store.update(cx, |ep_store, cx| {
 616        ep_store.register_buffer(&buffer, &project, cx);
 617    });
 618
 619    // Case 1: Manual edits have `predicted` set to false.
 620    buffer.update(cx, |buffer, cx| {
 621        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
 622    });
 623
 624    let events = ep_store.update(cx, |ep_store, cx| {
 625        ep_store.edit_history_for_project(&project, cx)
 626    });
 627
 628    assert_eq!(
 629        render_events_with_predicted(&events),
 630        vec![indoc! {"
 631            manual
 632            @@ -1,4 +1,4 @@
 633            -line 0
 634            +LINE ZERO
 635             line 1
 636             line 2
 637             line 3
 638        "}]
 639    );
 640
 641    // Case 2: Multiple successive manual edits near each other are merged into one
 642    // event with `predicted` set to false.
 643    buffer.update(cx, |buffer, cx| {
 644        let offset = Point::new(1, 0).to_offset(buffer);
 645        let end = Point::new(1, 6).to_offset(buffer);
 646        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
 647    });
 648
 649    let events = ep_store.update(cx, |ep_store, cx| {
 650        ep_store.edit_history_for_project(&project, cx)
 651    });
 652    assert_eq!(
 653        render_events_with_predicted(&events),
 654        vec![indoc! {"
 655            manual
 656            @@ -1,5 +1,5 @@
 657            -line 0
 658            -line 1
 659            +LINE ZERO
 660            +LINE ONE
 661             line 2
 662             line 3
 663             line 4
 664        "}]
 665    );
 666
 667    // Case 3: Accepted predictions have `predicted` set to true.
 668    // Case 5: A manual edit that follows a predicted edit is not merged with the
 669    // predicted edit, even if it is nearby.
 670    ep_store.update(cx, |ep_store, cx| {
 671        buffer.update(cx, |buffer, cx| {
 672            let offset = Point::new(2, 0).to_offset(buffer);
 673            let end = Point::new(2, 6).to_offset(buffer);
 674            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
 675        });
 676        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
 677    });
 678
 679    let events = ep_store.update(cx, |ep_store, cx| {
 680        ep_store.edit_history_for_project(&project, cx)
 681    });
 682    assert_eq!(
 683        render_events_with_predicted(&events),
 684        vec![
 685            indoc! {"
 686                manual
 687                @@ -1,5 +1,5 @@
 688                -line 0
 689                -line 1
 690                +LINE ZERO
 691                +LINE ONE
 692                 line 2
 693                 line 3
 694                 line 4
 695            "},
 696            indoc! {"
 697                predicted
 698                @@ -1,6 +1,6 @@
 699                 LINE ZERO
 700                 LINE ONE
 701                -line 2
 702                +LINE TWO
 703                 line 3
 704                 line 4
 705                 line 5
 706            "}
 707        ]
 708    );
 709
 710    // Case 4: Multiple successive accepted predictions near each other are merged
 711    // into one event with `predicted` set to true.
 712    ep_store.update(cx, |ep_store, cx| {
 713        buffer.update(cx, |buffer, cx| {
 714            let offset = Point::new(3, 0).to_offset(buffer);
 715            let end = Point::new(3, 6).to_offset(buffer);
 716            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
 717        });
 718        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
 719    });
 720
 721    let events = ep_store.update(cx, |ep_store, cx| {
 722        ep_store.edit_history_for_project(&project, cx)
 723    });
 724    assert_eq!(
 725        render_events_with_predicted(&events),
 726        vec![
 727            indoc! {"
 728                manual
 729                @@ -1,5 +1,5 @@
 730                -line 0
 731                -line 1
 732                +LINE ZERO
 733                +LINE ONE
 734                 line 2
 735                 line 3
 736                 line 4
 737            "},
 738            indoc! {"
 739                predicted
 740                @@ -1,7 +1,7 @@
 741                 LINE ZERO
 742                 LINE ONE
 743                -line 2
 744                -line 3
 745                +LINE TWO
 746                +LINE THREE
 747                 line 4
 748                 line 5
 749                 line 6
 750            "}
 751        ]
 752    );
 753
 754    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
 755    // with the predicted edit, even if it is nearby.
 756    buffer.update(cx, |buffer, cx| {
 757        let offset = Point::new(4, 0).to_offset(buffer);
 758        let end = Point::new(4, 6).to_offset(buffer);
 759        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
 760    });
 761
 762    let events = ep_store.update(cx, |ep_store, cx| {
 763        ep_store.edit_history_for_project(&project, cx)
 764    });
 765    assert_eq!(
 766        render_events_with_predicted(&events),
 767        vec![
 768            indoc! {"
 769                manual
 770                @@ -1,5 +1,5 @@
 771                -line 0
 772                -line 1
 773                +LINE ZERO
 774                +LINE ONE
 775                 line 2
 776                 line 3
 777                 line 4
 778            "},
 779            indoc! {"
 780                predicted
 781                @@ -1,7 +1,7 @@
 782                 LINE ZERO
 783                 LINE ONE
 784                -line 2
 785                -line 3
 786                +LINE TWO
 787                +LINE THREE
 788                 line 4
 789                 line 5
 790                 line 6
 791            "},
 792            indoc! {"
 793                manual
 794                @@ -2,7 +2,7 @@
 795                 LINE ONE
 796                 LINE TWO
 797                 LINE THREE
 798                -line 4
 799                +LINE FOUR
 800                 line 5
 801                 line 6
 802                 line 7
 803            "}
 804        ]
 805    );
 806
 807    // Case 6: If we then perform a manual edit at a *different* location (more than
 808    // 8 lines away), then the edits at the prior location can be merged with each
 809    // other, even if some are predicted and some are not. `predicted` means all
 810    // constituent edits were predicted.
 811    buffer.update(cx, |buffer, cx| {
 812        let offset = Point::new(14, 0).to_offset(buffer);
 813        let end = Point::new(14, 7).to_offset(buffer);
 814        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
 815    });
 816
 817    let events = ep_store.update(cx, |ep_store, cx| {
 818        ep_store.edit_history_for_project(&project, cx)
 819    });
 820    assert_eq!(
 821        render_events_with_predicted(&events),
 822        vec![
 823            indoc! {"
 824                manual
 825                @@ -1,8 +1,8 @@
 826                -line 0
 827                -line 1
 828                -line 2
 829                -line 3
 830                -line 4
 831                +LINE ZERO
 832                +LINE ONE
 833                +LINE TWO
 834                +LINE THREE
 835                +LINE FOUR
 836                 line 5
 837                 line 6
 838                 line 7
 839            "},
 840            indoc! {"
 841                manual
 842                @@ -12,4 +12,4 @@
 843                 line 11
 844                 line 12
 845                 line 13
 846                -line 14
 847                +LINE FOURTEEN
 848            "}
 849        ]
 850    );
 851}
 852
 853#[gpui::test]
 854async fn test_empty_prediction(cx: &mut TestAppContext) {
 855    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 856    let fs = FakeFs::new(cx.executor());
 857    fs.insert_tree(
 858        "/root",
 859        json!({
 860            "foo.md":  "Hello!\nHow\nBye\n"
 861        }),
 862    )
 863    .await;
 864    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 865
 866    let buffer = project
 867        .update(cx, |project, cx| {
 868            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 869            project.open_buffer(path, cx)
 870        })
 871        .await
 872        .unwrap();
 873    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 874    let position = snapshot.anchor_before(language::Point::new(1, 3));
 875
 876    ep_store.update(cx, |ep_store, cx| {
 877        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 878    });
 879
 880    let (request, respond_tx) = requests.predict.next().await.unwrap();
 881    let response = model_response(&request, "");
 882    let id = response.request_id.clone();
 883    respond_tx.send(response).unwrap();
 884
 885    cx.run_until_parked();
 886
 887    ep_store.update(cx, |ep_store, cx| {
 888        assert!(
 889            ep_store
 890                .prediction_at(&buffer, None, &project, cx)
 891                .is_none()
 892        );
 893    });
 894
 895    // prediction is reported as rejected
 896    let (reject_request, _) = requests.reject.next().await.unwrap();
 897
 898    assert_eq!(
 899        &reject_request.rejections,
 900        &[EditPredictionRejection {
 901            request_id: id,
 902            reason: EditPredictionRejectReason::Empty,
 903            was_shown: false,
 904            model_version: None,
 905        }]
 906    );
 907}
 908
// While a prediction request is in flight, the user edits the buffer so that it
// already contains the text the prediction would insert ("How" -> "How are you?").
// When the response arrives, no prediction should be shown at the cursor, and the
// prediction must be reported as rejected with `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Row 1, column 3: the end of the line "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Hold the request open so the buffer can change before the response lands.
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The user types exactly what SIMPLE_DIFF would have inserted.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    // The response is computed against the request's (pre-edit) excerpt.
    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
 969
/// Unified diff against `root/foo.md` that rewrites the line `How` to
/// `How are you?`, with `Hello!` and `Bye` as context lines. The hunk header
/// (`@@ ... @@`) carries no line numbers. Fed to `model_response` to fabricate
/// a model output for the excerpt in a request.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 979
// Two successive predictions at the same position, both produced from
// SIMPLE_DIFF: the second becomes the current prediction, and the first is
// reported as rejected with `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Row 1, column 3: the end of the line "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1062
// A newer prediction that the store judges worse than the current one is
// discarded: the current (first) prediction stays, and the second is reported
// as rejected with `CurrentPreferred`. Here the second prediction inserts only
// " are" where the first inserted " are you?". The exact "worse" heuristic
// lives in the store and is not visible from this test.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Row 1, column 3: the end of the line "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1157
// Two prediction requests are in flight and the second one responds first. The
// second becomes the current prediction. When the first (earlier) request
// responds afterwards, it is ignored and reported as rejected with `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Row 1, column 3: the end of the line "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // The first (older) request only responds now, after the second has landed.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1249
// With two requests already pending, starting a third marks the second as
// cancelled. The responses then arrive in order first, second, third. The first
// becomes current, the cancelled second is ignored, and the third replaces the
// first. Both rejections are delivered in a single reject request: the second
// as `Canceled`, then the first as `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Row 1, column 3: the end of the line "How".
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        // (the recorded value 1 identifies the second request; presumably a
        // 0-based id/index assigned per request - confirm in the store)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second (canceled) and first (replaced) are reported together in one batch
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1387
// Edit predictions (triggered from a buffer position) and jump predictions
// (triggered by diagnostics) are throttled separately. Issuing one kind does
// not throttle the other, but a repeat of the same kind is held back until
// `EditPredictionStore::THROTTLE_TIMEOUT` elapses.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Register the project and buffer up front so diagnostic events reach the store.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request, triggered by publishing a diagnostic for bar.md (not the open
    // foo.md buffer) - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1488
// Two identical refresh calls issued back-to-back inside one `update` closure
// must result in exactly one prediction request reaching the server.
#[gpui::test]
async fn test_same_frame_duplicate_requests_deduplicated(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Enqueue two refresh calls in the same synchronous frame (no yielding).
    // Both `cx.spawn` tasks are created before either executes, so they both
    // capture the same `proceed_count_at_enqueue`. Only the first task should
    // pass the deduplication gate; the second should be skipped.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Let both spawned tasks run to completion (including any throttle waits).
    cx.run_until_parked();

    // Exactly one prediction request should have been sent.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(&request, SIMPLE_DIFF))
        .unwrap();
    cx.run_until_parked();

    // No second request should be pending.
    assert_no_predict_request_ready(&mut requests.predict);
}
1534
1535#[gpui::test]
1536async fn test_rejections_flushing(cx: &mut TestAppContext) {
1537    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1538
1539    ep_store.update(cx, |ep_store, cx| {
1540        ep_store.reject_prediction(
1541            EditPredictionId("test-1".into()),
1542            EditPredictionRejectReason::Discarded,
1543            false,
1544            None,
1545            cx,
1546        );
1547        ep_store.reject_prediction(
1548            EditPredictionId("test-2".into()),
1549            EditPredictionRejectReason::Canceled,
1550            true,
1551            None,
1552            cx,
1553        );
1554    });
1555
1556    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1557    cx.run_until_parked();
1558
1559    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1560    respond_tx.send(()).unwrap();
1561
1562    // batched
1563    assert_eq!(reject_request.rejections.len(), 2);
1564    assert_eq!(
1565        reject_request.rejections[0],
1566        EditPredictionRejection {
1567            request_id: "test-1".to_string(),
1568            reason: EditPredictionRejectReason::Discarded,
1569            was_shown: false,
1570            model_version: None,
1571        }
1572    );
1573    assert_eq!(
1574        reject_request.rejections[1],
1575        EditPredictionRejection {
1576            request_id: "test-2".to_string(),
1577            reason: EditPredictionRejectReason::Canceled,
1578            was_shown: true,
1579            model_version: None,
1580        }
1581    );
1582
1583    // Reaching batch size limit sends without debounce
1584    ep_store.update(cx, |ep_store, cx| {
1585        for i in 0..70 {
1586            ep_store.reject_prediction(
1587                EditPredictionId(format!("batch-{}", i).into()),
1588                EditPredictionRejectReason::Discarded,
1589                false,
1590                None,
1591                cx,
1592            );
1593        }
1594    });
1595
1596    // First MAX/2 items are sent immediately
1597    cx.run_until_parked();
1598    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1599    respond_tx.send(()).unwrap();
1600
1601    assert_eq!(reject_request.rejections.len(), 50);
1602    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1603    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1604
1605    // Remaining items are debounced with the next batch
1606    cx.executor().advance_clock(Duration::from_secs(15));
1607    cx.run_until_parked();
1608
1609    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1610    respond_tx.send(()).unwrap();
1611
1612    assert_eq!(reject_request.rejections.len(), 20);
1613    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1614    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1615
1616    // Request failure
1617    ep_store.update(cx, |ep_store, cx| {
1618        ep_store.reject_prediction(
1619            EditPredictionId("retry-1".into()),
1620            EditPredictionRejectReason::Discarded,
1621            false,
1622            None,
1623            cx,
1624        );
1625    });
1626
1627    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1628    cx.run_until_parked();
1629
1630    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1631    assert_eq!(reject_request.rejections.len(), 1);
1632    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1633    // Simulate failure
1634    drop(_respond_tx);
1635
1636    // Add another rejection
1637    ep_store.update(cx, |ep_store, cx| {
1638        ep_store.reject_prediction(
1639            EditPredictionId("retry-2".into()),
1640            EditPredictionRejectReason::Discarded,
1641            false,
1642            None,
1643            cx,
1644        );
1645    });
1646
1647    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1648    cx.run_until_parked();
1649
1650    // Retry should include both the failed item and the new one
1651    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1652    respond_tx.send(()).unwrap();
1653
1654    assert_eq!(reject_request.rejections.len(), 2);
1655    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1656    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1657}
1658
1659// Skipped until we start including diagnostics in prompt
1660// #[gpui::test]
1661// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1662//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1663//     let fs = FakeFs::new(cx.executor());
1664//     fs.insert_tree(
1665//         "/root",
1666//         json!({
1667//             "foo.md": "Hello!\nBye"
1668//         }),
1669//     )
1670//     .await;
1671//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1672
1673//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1674//     let diagnostic = lsp::Diagnostic {
1675//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1676//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1677//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1678//         ..Default::default()
1679//     };
1680
1681//     project.update(cx, |project, cx| {
1682//         project.lsp_store().update(cx, |lsp_store, cx| {
1683//             // Create some diagnostics
1684//             lsp_store
1685//                 .update_diagnostics(
1686//                     LanguageServerId(0),
1687//                     lsp::PublishDiagnosticsParams {
1688//                         uri: path_to_buffer_uri.clone(),
1689//                         diagnostics: vec![diagnostic],
1690//                         version: None,
1691//                     },
1692//                     None,
1693//                     language::DiagnosticSourceKind::Pushed,
1694//                     &[],
1695//                     cx,
1696//                 )
1697//                 .unwrap();
1698//         });
1699//     });
1700
1701//     let buffer = project
1702//         .update(cx, |project, cx| {
1703//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1704//             project.open_buffer(path, cx)
1705//         })
1706//         .await
1707//         .unwrap();
1708
1709//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1710//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1711
1712//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1713//         ep_store.request_prediction(&project, &buffer, position, cx)
1714//     });
1715
1716//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1717
1718//     assert_eq!(request.diagnostic_groups.len(), 1);
1719//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1720//         .unwrap();
1721//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1722//     assert_eq!(
1723//         value,
1724//         json!({
1725//             "entries": [{
1726//                 "range": {
1727//                     "start": 8,
1728//                     "end": 10
1729//                 },
1730//                 "diagnostic": {
1731//                     "source": null,
1732//                     "code": null,
1733//                     "code_description": null,
1734//                     "severity": 1,
1735//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1736//                     "markdown": null,
1737//                     "group_id": 0,
1738//                     "is_primary": true,
1739//                     "is_disk_based": false,
1740//                     "is_unnecessary": false,
1741//                     "source_kind": "Pushed",
1742//                     "data": null,
1743//                     "underline": true
1744//                 }
1745//             }],
1746//             "primary_ix": 0
1747//         })
1748//     );
1749// }
1750
1751// Generate a model response that would apply the given diff to the active file.
1752fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1753    let editable_range =
1754        zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1755    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1756    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1757
1758    PredictEditsV3Response {
1759        request_id: Uuid::new_v4().to_string(),
1760        editable_range,
1761        output: new_excerpt,
1762        model_version: None,
1763    }
1764}
1765
1766fn empty_response() -> PredictEditsV3Response {
1767    PredictEditsV3Response {
1768        request_id: Uuid::new_v4().to_string(),
1769        editable_range: 0..0,
1770        output: String::new(),
1771        model_version: None,
1772    }
1773}
1774
1775fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1776    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1777}
1778
1779fn assert_no_predict_request_ready(
1780    requests: &mut mpsc::UnboundedReceiver<(
1781        PredictEditsV3Request,
1782        oneshot::Sender<PredictEditsV3Response>,
1783    )>,
1784) {
1785    if requests.next().now_or_never().flatten().is_some() {
1786        panic!("Unexpected prediction request while throttled.");
1787    }
1788}
1789
/// Receiving ends for the requests captured by the fake HTTP client in
/// `init_test_with_fake_client`. Each item pairs the decoded request body with
/// a oneshot sender; the test sends the response on it, or drops it to make the
/// fake request fail.
struct RequestChannels {
    /// Requests to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Requests to `/predict_edits/reject`; reply with `()` to acknowledge.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1797
1798fn init_test_with_fake_client(
1799    cx: &mut TestAppContext,
1800) -> (Entity<EditPredictionStore>, RequestChannels) {
1801    cx.update(move |cx| {
1802        let settings_store = SettingsStore::test(cx);
1803        cx.set_global(settings_store);
1804        zlog::init_test();
1805
1806        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1807        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1808
1809        let http_client = FakeHttpClient::create({
1810            move |req| {
1811                let uri = req.uri().path().to_string();
1812                let mut body = req.into_body();
1813                let predict_req_tx = predict_req_tx.clone();
1814                let reject_req_tx = reject_req_tx.clone();
1815                async move {
1816                    let resp = match uri.as_str() {
1817                        "/client/llm_tokens" => serde_json::to_string(&json!({
1818                            "token": "test"
1819                        }))
1820                        .unwrap(),
1821                        "/predict_edits/v3" => {
1822                            let mut buf = Vec::new();
1823                            body.read_to_end(&mut buf).await.ok();
1824                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
1825                            let req = serde_json::from_slice(&decompressed).unwrap();
1826
1827                            let (res_tx, res_rx) = oneshot::channel();
1828                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1829                            serde_json::to_string(&res_rx.await?).unwrap()
1830                        }
1831                        "/predict_edits/reject" => {
1832                            let mut buf = Vec::new();
1833                            body.read_to_end(&mut buf).await.ok();
1834                            let req = serde_json::from_slice(&buf).unwrap();
1835
1836                            let (res_tx, res_rx) = oneshot::channel();
1837                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1838                            serde_json::to_string(&res_rx.await?).unwrap()
1839                        }
1840                        _ => {
1841                            panic!("Unexpected path: {}", uri)
1842                        }
1843                    };
1844
1845                    Ok(Response::builder().body(resp.into()).unwrap())
1846                }
1847            }
1848        });
1849
1850        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1851        client.cloud_client().set_credentials(1, "test".into());
1852
1853        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1854        language_model::init(user_store.clone(), client.clone(), cx);
1855        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1856
1857        (
1858            ep_store,
1859            RequestChannels {
1860                predict: predict_req_rx,
1861                reject: reject_req_rx,
1862            },
1863        )
1864    })
1865}
1866
1867#[gpui::test]
1868async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1869    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1870    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1871        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1872    });
1873
1874    let edit_preview = cx
1875        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1876        .await;
1877
1878    let prediction = EditPrediction {
1879        edits,
1880        cursor_position: None,
1881        edit_preview,
1882        buffer: buffer.clone(),
1883        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1884        id: EditPredictionId("the-id".into()),
1885        inputs: ZetaPromptInput {
1886            events: Default::default(),
1887            related_files: Default::default(),
1888            cursor_path: Path::new("").into(),
1889            cursor_excerpt: "".into(),
1890            cursor_offset_in_excerpt: 0,
1891            excerpt_start_row: None,
1892            excerpt_ranges: Default::default(),
1893            experiment: None,
1894            in_open_source_repo: false,
1895            can_collect_data: false,
1896            repo_url: None,
1897        },
1898        buffer_snapshotted_at: Instant::now(),
1899        response_received_at: Instant::now(),
1900        model_version: None,
1901    };
1902
1903    cx.update(|cx| {
1904        assert_eq!(
1905            from_completion_edits(
1906                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1907                &buffer,
1908                cx
1909            ),
1910            vec![(2..5, "REM".into()), (9..11, "".into())]
1911        );
1912
1913        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1914        assert_eq!(
1915            from_completion_edits(
1916                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1917                &buffer,
1918                cx
1919            ),
1920            vec![(2..2, "REM".into()), (6..8, "".into())]
1921        );
1922
1923        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1924        assert_eq!(
1925            from_completion_edits(
1926                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1927                &buffer,
1928                cx
1929            ),
1930            vec![(2..5, "REM".into()), (9..11, "".into())]
1931        );
1932
1933        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1934        assert_eq!(
1935            from_completion_edits(
1936                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1937                &buffer,
1938                cx
1939            ),
1940            vec![(3..3, "EM".into()), (7..9, "".into())]
1941        );
1942
1943        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1944        assert_eq!(
1945            from_completion_edits(
1946                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1947                &buffer,
1948                cx
1949            ),
1950            vec![(4..4, "M".into()), (8..10, "".into())]
1951        );
1952
1953        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1954        assert_eq!(
1955            from_completion_edits(
1956                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1957                &buffer,
1958                cx
1959            ),
1960            vec![(9..11, "".into())]
1961        );
1962
1963        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1964        assert_eq!(
1965            from_completion_edits(
1966                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1967                &buffer,
1968                cx
1969            ),
1970            vec![(4..4, "M".into()), (8..10, "".into())]
1971        );
1972
1973        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1974        assert_eq!(
1975            from_completion_edits(
1976                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1977                &buffer,
1978                cx
1979            ),
1980            vec![(4..4, "M".into())]
1981        );
1982
1983        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1984        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1985    })
1986}
1987
1988#[gpui::test]
1989async fn test_clean_up_diff(cx: &mut TestAppContext) {
1990    init_test(cx);
1991
1992    assert_eq!(
1993        apply_edit_prediction(
1994            indoc! {"
1995                    fn main() {
1996                        let word_1 = \"lorem\";
1997                        let range = word.len()..word.len();
1998                    }
1999                "},
2000            indoc! {"
2001                    fn main() {
2002                        let word_1 = \"lorem\";
2003                        let range = word_1.len()..word_1.len();
2004                    }
2005                "},
2006            cx,
2007        )
2008        .await,
2009        indoc! {"
2010                fn main() {
2011                    let word_1 = \"lorem\";
2012                    let range = word_1.len()..word_1.len();
2013                }
2014            "},
2015    );
2016
2017    assert_eq!(
2018        apply_edit_prediction(
2019            indoc! {"
2020                    fn main() {
2021                        let story = \"the quick\"
2022                    }
2023                "},
2024            indoc! {"
2025                    fn main() {
2026                        let story = \"the quick brown fox jumps over the lazy dog\";
2027                    }
2028                "},
2029            cx,
2030        )
2031        .await,
2032        indoc! {"
2033                fn main() {
2034                    let story = \"the quick brown fox jumps over the lazy dog\";
2035                }
2036            "},
2037    );
2038}
2039
2040#[gpui::test]
2041async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2042    init_test(cx);
2043
2044    let buffer_content = "lorem\n";
2045    let completion_response = "lorem\nipsum\n";
2046
2047    assert_eq!(
2048        apply_edit_prediction(buffer_content, completion_response, cx).await,
2049        "lorem\nipsum\n"
2050    );
2051}
2052
2053#[gpui::test]
2054async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
2055    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
2056    // When the buffer ends without a trailing newline, but the model returns output
2057    // with a trailing newline, zeta2 should normalize both sides before diffing
2058    // so no spurious newline is inserted.
2059    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2060    let fs = FakeFs::new(cx.executor());
2061
2062    // Single line buffer with no trailing newline
2063    fs.insert_tree(
2064        "/root",
2065        json!({
2066            "foo.txt": "hello"
2067        }),
2068    )
2069    .await;
2070    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2071
2072    let buffer = project
2073        .update(cx, |project, cx| {
2074            let path = project
2075                .find_project_path(path!("root/foo.txt"), cx)
2076                .unwrap();
2077            project.open_buffer(path, cx)
2078        })
2079        .await
2080        .unwrap();
2081
2082    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2083    let position = snapshot.anchor_before(language::Point::new(0, 5));
2084
2085    ep_store.update(cx, |ep_store, cx| {
2086        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2087    });
2088
2089    let (request, respond_tx) = requests.predict.next().await.unwrap();
2090
2091    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2092    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2093    let excerpt_length = request.input.cursor_excerpt.len();
2094    let response = PredictEditsV3Response {
2095        request_id: Uuid::new_v4().to_string(),
2096        output: "hello world\n".to_string(),
2097        editable_range: 0..excerpt_length,
2098        model_version: None,
2099    };
2100    respond_tx.send(response).unwrap();
2101
2102    cx.run_until_parked();
2103
2104    // The prediction should insert " world" without adding a newline
2105    ep_store.update(cx, |ep_store, cx| {
2106        let prediction = ep_store
2107            .prediction_at(&buffer, None, &project, cx)
2108            .expect("should have prediction");
2109        let edits: Vec<_> = prediction
2110            .edits
2111            .iter()
2112            .map(|(range, text)| {
2113                let snapshot = buffer.read(cx).snapshot();
2114                (range.to_offset(&snapshot), text.clone())
2115            })
2116            .collect();
2117        assert_eq!(edits, vec![(5..5, " world".into())]);
2118    });
2119}
2120
2121fn init_test(cx: &mut TestAppContext) {
2122    cx.update(|cx| {
2123        let settings_store = SettingsStore::test(cx);
2124        cx.set_global(settings_store);
2125    });
2126}
2127
2128async fn apply_edit_prediction(
2129    buffer_content: &str,
2130    completion_response: &str,
2131    cx: &mut TestAppContext,
2132) -> String {
2133    let fs = project::FakeFs::new(cx.executor());
2134    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2135    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2136    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2137    *response.lock() = completion_response.to_string();
2138    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2139    buffer.update(cx, |buffer, cx| {
2140        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2141    });
2142    buffer.read_with(cx, |buffer, _| buffer.text())
2143}
2144
2145async fn run_edit_prediction(
2146    buffer: &Entity<Buffer>,
2147    project: &Entity<Project>,
2148    ep_store: &Entity<EditPredictionStore>,
2149    cx: &mut TestAppContext,
2150) -> EditPrediction {
2151    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2152    ep_store.update(cx, |ep_store, cx| {
2153        ep_store.register_buffer(buffer, &project, cx)
2154    });
2155    cx.background_executor.run_until_parked();
2156    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2157        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2158    });
2159    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2160}
2161
2162async fn make_test_ep_store(
2163    project: &Entity<Project>,
2164    cx: &mut TestAppContext,
2165) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2166    let default_response = "hello world\n".to_string();
2167    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2168    let http_client = FakeHttpClient::create({
2169        let completion_response = completion_response.clone();
2170        let mut next_request_id = 0;
2171        move |req| {
2172            let completion_response = completion_response.clone();
2173            let method = req.method().clone();
2174            let uri = req.uri().path().to_string();
2175            let mut body = req.into_body();
2176            async move {
2177                match (method, uri.as_str()) {
2178                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2179                        .status(200)
2180                        .body(
2181                            serde_json::to_string(&CreateLlmTokenResponse {
2182                                token: LlmToken("the-llm-token".to_string()),
2183                            })
2184                            .unwrap()
2185                            .into(),
2186                        )
2187                        .unwrap()),
2188                    (Method::POST, "/predict_edits/v3") => {
2189                        let mut buf = Vec::new();
2190                        body.read_to_end(&mut buf).await.ok();
2191                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
2192                        let req: PredictEditsV3Request =
2193                            serde_json::from_slice(&decompressed).unwrap();
2194
2195                        next_request_id += 1;
2196                        Ok(http_client::Response::builder()
2197                            .status(200)
2198                            .body(
2199                                serde_json::to_string(&PredictEditsV3Response {
2200                                    request_id: format!("request-{next_request_id}"),
2201                                    editable_range: 0..req.input.cursor_excerpt.len(),
2202                                    output: completion_response.lock().clone(),
2203                                    model_version: None,
2204                                })
2205                                .unwrap()
2206                                .into(),
2207                            )
2208                            .unwrap())
2209                    }
2210                    _ => Ok(http_client::Response::builder()
2211                        .status(404)
2212                        .body("Not Found".to_string().into())
2213                        .unwrap()),
2214                }
2215            }
2216        }
2217    });
2218
2219    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2220    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
2221    cx.update(|cx| {
2222        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
2223    });
2224    let _server = FakeServer::for_client(42, &client, cx).await;
2225
2226    let ep_store = cx.new(|cx| {
2227        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2228        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2229
2230        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2231        for worktree in worktrees {
2232            let worktree_id = worktree.read(cx).id();
2233            ep_store
2234                .get_or_init_project(project, cx)
2235                .license_detection_watchers
2236                .entry(worktree_id)
2237                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2238        }
2239
2240        ep_store
2241    });
2242
2243    (ep_store, completion_response)
2244}
2245
2246fn to_completion_edits(
2247    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2248    buffer: &Entity<Buffer>,
2249    cx: &App,
2250) -> Vec<(Range<Anchor>, Arc<str>)> {
2251    let buffer = buffer.read(cx);
2252    iterator
2253        .into_iter()
2254        .map(|(range, text)| {
2255            (
2256                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2257                text,
2258            )
2259        })
2260        .collect()
2261}
2262
2263fn from_completion_edits(
2264    editor_edits: &[(Range<Anchor>, Arc<str>)],
2265    buffer: &Entity<Buffer>,
2266    cx: &App,
2267) -> Vec<(Range<usize>, Arc<str>)> {
2268    let buffer = buffer.read(cx);
2269    editor_edits
2270        .iter()
2271        .map(|(range, text)| {
2272            (
2273                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2274                text.clone(),
2275            )
2276        })
2277        .collect()
2278}
2279
2280#[gpui::test]
2281async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2282    init_test(cx);
2283
2284    let fs = FakeFs::new(cx.executor());
2285    fs.insert_tree(
2286        "/project",
2287        serde_json::json!({
2288            "main.rs": "fn main() {\n    \n}\n"
2289        }),
2290    )
2291    .await;
2292
2293    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2294
2295    let http_client = FakeHttpClient::create(|_req| async move {
2296        Ok(gpui::http_client::Response::builder()
2297            .status(401)
2298            .body("Unauthorized".into())
2299            .unwrap())
2300    });
2301
2302    let client =
2303        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2304    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
2305    cx.update(|cx| {
2306        language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
2307    });
2308
2309    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2310
2311    let buffer = project
2312        .update(cx, |project, cx| {
2313            let path = project
2314                .find_project_path(path!("/project/main.rs"), cx)
2315                .unwrap();
2316            project.open_buffer(path, cx)
2317        })
2318        .await
2319        .unwrap();
2320
2321    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2322    ep_store.update(cx, |ep_store, cx| {
2323        ep_store.register_buffer(&buffer, &project, cx)
2324    });
2325    cx.background_executor.run_until_parked();
2326
2327    let completion_task = ep_store.update(cx, |ep_store, cx| {
2328        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2329        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2330    });
2331
2332    let result = completion_task.await;
2333    assert!(
2334        result.is_err(),
2335        "Without authentication and without custom URL, prediction should fail"
2336    );
2337}
2338
2339#[gpui::test]
2340fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2341    let buffer = cx.new(|cx| {
2342        Buffer::local(
2343            indoc! {"
2344                zero
2345                one
2346                two
2347                three
2348                four
2349                five
2350                six
2351                seven
2352                eight
2353                nine
2354                ten
2355                eleven
2356                twelve
2357                thirteen
2358                fourteen
2359                fifteen
2360                sixteen
2361                seventeen
2362                eighteen
2363                nineteen
2364                twenty
2365                twenty-one
2366                twenty-two
2367                twenty-three
2368                twenty-four
2369            "},
2370            cx,
2371        )
2372    });
2373
2374    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2375
2376    buffer.update(cx, |buffer, cx| {
2377        let point = Point::new(12, 0);
2378        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2379        let point = Point::new(8, 0);
2380        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2381    });
2382
2383    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2384
2385    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2386
2387    assert_eq!(
2388        diff,
2389        indoc! {"
2390            @@ -6,10 +6,12 @@
2391             five
2392             six
2393             seven
2394            +FIRST INSERTION
2395             eight
2396             nine
2397             ten
2398             eleven
2399            +SECOND INSERTION
2400             twelve
2401             thirteen
2402             fourteen
2403            "}
2404    );
2405}
2406
/// Checks that `EditPredictionStore::next_diagnostic_location` avoids diagnostics
/// near a collaborator's cursor unless the local cursor is near them too. The
/// check covers both the active buffer and other files with diagnostics.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    /// Places a remote collaborator's cursor (replica 10) at the start of `row`
    /// by applying an `UpdateSelections` operation to `buffer`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        // An empty selection (start == end) acts as a plain cursor.
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    /// Publishes one ERROR diagnostic per entry in `rows` (columns 0..5) for the
    /// file at `uri_path`, attributed to language server 0.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 numbered lines, so diagnostics can sit at well-separated rows.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator at row 5; diagnostics at rows 3, 25 and 50 of the active file.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25. Row 3 is next to the collaborator but not
    // to us, so it must not be chosen.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, next to the collaborator as well, so row 3
    // is allowed again.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50, far from the collaborator; row 50 is picked.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4 (cross-file): both other files get a diagnostic at row 0, and a
    // collaborator is placed at row 0 of collab_file.txt.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // NOTE(review): this range spans rows 0..59 of active.txt, covering all of its
    // diagnostics. Judging by the variable name, it is meant to rule out
    // same-file results so that only other files are candidates; confirm against
    // `next_diagnostic_location`.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2622
2623#[gpui::test]
2624async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
2625    let (ep_store, _requests) = init_test_with_fake_client(cx);
2626    let fs = FakeFs::new(cx.executor());
2627
2628    // Buffer with two clearly separated regions:
2629    //   Region A = lines 0-9   (offsets 0..50)
2630    //   Region B = lines 20-29 (offsets 105..155)
2631    // A big gap in between so edits in one region never overlap the other.
2632    let mut content = String::new();
2633    for i in 0..30 {
2634        content.push_str(&format!("line {i:02}\n"));
2635    }
2636
2637    fs.insert_tree(
2638        "/root",
2639        json!({
2640            "foo.md": content.clone()
2641        }),
2642    )
2643    .await;
2644    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2645
2646    let buffer = project
2647        .update(cx, |project, cx| {
2648            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2649            project.open_buffer(path, cx)
2650        })
2651        .await
2652        .unwrap();
2653
2654    type SettledEventRecord = (EditPredictionId, String);
2655    let settled_events: Arc<Mutex<Vec<SettledEventRecord>>> = Arc::new(Mutex::new(Vec::new()));
2656
2657    ep_store.update(cx, |ep_store, cx| {
2658        ep_store.register_buffer(&buffer, &project, cx);
2659
2660        let settled_events = settled_events.clone();
2661        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
2662            settled_events.lock().push((id, text));
2663        }));
2664    });
2665
2666    // --- Phase 1: edit in region A and enqueue prediction A ---
2667
2668    buffer.update(cx, |buffer, cx| {
2669        // Edit at the start of line 0.
2670        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
2671    });
2672    cx.run_until_parked();
2673
2674    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2675
2676    // Region A: first 10 lines of the buffer.
2677    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
2678
2679    ep_store.update(cx, |ep_store, cx| {
2680        ep_store.enqueue_settled_prediction(
2681            EditPredictionId("prediction-a".into()),
2682            &project,
2683            &buffer,
2684            &snapshot_a,
2685            editable_region_a.clone(),
2686            None,
2687            cx,
2688        );
2689    });
2690
2691    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---
2692
2693    // Let the worker process the channel message before we start advancing.
2694    cx.run_until_parked();
2695
2696    let mut region_a_edit_offset = 5;
2697    for _ in 0..3 {
2698        // Edit inside region A (not at the boundary) so `last_edit_at` is
2699        // updated before the worker's next wake.
2700        buffer.update(cx, |buffer, cx| {
2701            buffer.edit(
2702                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
2703                None,
2704                cx,
2705            );
2706        });
2707        region_a_edit_offset += 1;
2708        cx.run_until_parked();
2709
2710        cx.executor()
2711            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
2712        cx.run_until_parked();
2713        assert!(
2714            settled_events.lock().is_empty(),
2715            "no settled events should fire while region A is still being edited"
2716        );
2717    }
2718
2719    // Still nothing settled.
2720    assert!(settled_events.lock().is_empty());
2721
2722    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
2723    // Advance a small amount so B's quiescence window starts later than A's,
2724    // but not so much that A settles (A's last edit was at the start of
2725    // iteration 3, and it needs a full Q to settle).
2726    cx.executor()
2727        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2728    cx.run_until_parked();
2729    assert!(settled_events.lock().is_empty());
2730
2731    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2732    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));
2733
2734    buffer.update(cx, |buffer, cx| {
2735        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
2736    });
2737    cx.run_until_parked();
2738
2739    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2740    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
2741
2742    ep_store.update(cx, |ep_store, cx| {
2743        ep_store.enqueue_settled_prediction(
2744            EditPredictionId("prediction-b".into()),
2745            &project,
2746            &buffer,
2747            &snapshot_b2,
2748            editable_region_b.clone(),
2749            None,
2750            cx,
2751        );
2752    });
2753
2754    cx.run_until_parked();
2755    assert!(
2756        settled_events.lock().is_empty(),
2757        "neither prediction should have settled yet"
2758    );
2759
2760    // --- Phase 4: let enough time pass for region A to settle ---
2761    // A's last edit was at T_a (during the last loop iteration). The worker is
2762    // sleeping until T_a + Q. We advance just enough to reach that wake time
2763    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
2764    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
2765    // enqueued only Q/4 ago and stays pending.
2766    cx.executor()
2767        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2768    cx.run_until_parked();
2769
2770    {
2771        let events = settled_events.lock().clone();
2772        assert_eq!(
2773            events.len(),
2774            1,
2775            "prediction and capture_sample for A should have settled, got: {events:?}"
2776        );
2777        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
2778    }
2779
2780    // --- Phase 5: let more time pass for region B to settle ---
2781    // B's last edit was Q/4 before A settled. The worker rescheduled to
2782    // B's last_edit_at + Q, which is 3Q/4 from now.
2783    cx.executor()
2784        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
2785    cx.run_until_parked();
2786
2787    {
2788        let events = settled_events.lock().clone();
2789        assert_eq!(
2790            events.len(),
2791            2,
2792            "both prediction and capture_sample settled events should be emitted for each request, got: {events:?}"
2793        );
2794        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
2795    }
2796}
2797
// Runs once when the test binary is loaded, before any test executes (via the
// `ctor` crate), so logging is set up for every test in this binary.
// NOTE(review): presumably `zlog::init_test` installs a test-friendly logger,
// as its name suggests — confirm against zlog.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}