edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
   3use client::{UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
   8    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
   9};
  10use futures::{
  11    AsyncReadExt, FutureExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::App;
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::path;
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{
  33    BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
  34    EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
  35};
  36
  37#[gpui::test]
  38async fn test_current_state(cx: &mut TestAppContext) {
  39    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  40    let fs = FakeFs::new(cx.executor());
  41    fs.insert_tree(
  42        "/root",
  43        json!({
  44            "1.txt": "Hello!\nHow\nBye\n",
  45            "2.txt": "Hola!\nComo\nAdios\n"
  46        }),
  47    )
  48    .await;
  49    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  50
  51    let buffer1 = project
  52        .update(cx, |project, cx| {
  53            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  54            project.set_active_path(Some(path.clone()), cx);
  55            project.open_buffer(path, cx)
  56        })
  57        .await
  58        .unwrap();
  59    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  60    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  61
  62    ep_store.update(cx, |ep_store, cx| {
  63        ep_store.register_project(&project, cx);
  64        ep_store.register_buffer(&buffer1, &project, cx);
  65    });
  66
  67    // Prediction for current file
  68
  69    ep_store.update(cx, |ep_store, cx| {
  70        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  71    });
  72    let (request, respond_tx) = requests.predict.next().await.unwrap();
  73
  74    respond_tx
  75        .send(model_response(
  76            &request,
  77            indoc! {r"
  78                --- a/root/1.txt
  79                +++ b/root/1.txt
  80                @@ ... @@
  81                 Hello!
  82                -How
  83                +How are you?
  84                 Bye
  85            "},
  86        ))
  87        .unwrap();
  88
  89    cx.run_until_parked();
  90
  91    ep_store.update(cx, |ep_store, cx| {
  92        let prediction = ep_store
  93            .prediction_at(&buffer1, None, &project, cx)
  94            .unwrap();
  95        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  96    });
  97
  98    ep_store.update(cx, |ep_store, cx| {
  99        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
 100    });
 101
 102    // Prediction for diagnostic in another file
 103
 104    let diagnostic = lsp::Diagnostic {
 105        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 106        severity: Some(lsp::DiagnosticSeverity::ERROR),
 107        message: "Sentence is incomplete".to_string(),
 108        ..Default::default()
 109    };
 110
 111    project.update(cx, |project, cx| {
 112        project.lsp_store().update(cx, |lsp_store, cx| {
 113            lsp_store
 114                .update_diagnostics(
 115                    LanguageServerId(0),
 116                    lsp::PublishDiagnosticsParams {
 117                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 118                        diagnostics: vec![diagnostic],
 119                        version: None,
 120                    },
 121                    None,
 122                    language::DiagnosticSourceKind::Pushed,
 123                    &[],
 124                    cx,
 125                )
 126                .unwrap();
 127        });
 128    });
 129
 130    let (request, respond_tx) = requests.predict.next().await.unwrap();
 131    respond_tx
 132        .send(model_response(
 133            &request,
 134            indoc! {r#"
 135                --- a/root/2.txt
 136                +++ b/root/2.txt
 137                @@ ... @@
 138                 Hola!
 139                -Como
 140                +Como estas?
 141                 Adios
 142            "#},
 143        ))
 144        .unwrap();
 145    cx.run_until_parked();
 146
 147    ep_store.update(cx, |ep_store, cx| {
 148        let prediction = ep_store
 149            .prediction_at(&buffer1, None, &project, cx)
 150            .unwrap();
 151        assert_matches!(
 152            prediction,
 153            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 154        );
 155    });
 156
 157    let buffer2 = project
 158        .update(cx, |project, cx| {
 159            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 160            project.open_buffer(path, cx)
 161        })
 162        .await
 163        .unwrap();
 164
 165    ep_store.update(cx, |ep_store, cx| {
 166        let prediction = ep_store
 167            .prediction_at(&buffer2, None, &project, cx)
 168            .unwrap();
 169        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 170    });
 171}
 172
 173#[gpui::test]
 174async fn test_simple_request(cx: &mut TestAppContext) {
 175    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 176    let fs = FakeFs::new(cx.executor());
 177    fs.insert_tree(
 178        "/root",
 179        json!({
 180            "foo.md":  "Hello!\nHow\nBye\n"
 181        }),
 182    )
 183    .await;
 184    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 185
 186    let buffer = project
 187        .update(cx, |project, cx| {
 188            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 189            project.open_buffer(path, cx)
 190        })
 191        .await
 192        .unwrap();
 193    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 194    let position = snapshot.anchor_before(language::Point::new(1, 3));
 195
 196    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 197        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 198    });
 199
 200    let (request, respond_tx) = requests.predict.next().await.unwrap();
 201
 202    // TODO Put back when we have a structured request again
 203    // assert_eq!(
 204    //     request.excerpt_path.as_ref(),
 205    //     Path::new(path!("root/foo.md"))
 206    // );
 207    // assert_eq!(
 208    //     request.cursor_point,
 209    //     Point {
 210    //         line: Line(1),
 211    //         column: 3
 212    //     }
 213    // );
 214
 215    respond_tx
 216        .send(model_response(
 217            &request,
 218            indoc! { r"
 219                --- a/root/foo.md
 220                +++ b/root/foo.md
 221                @@ ... @@
 222                 Hello!
 223                -How
 224                +How are you?
 225                 Bye
 226            "},
 227        ))
 228        .unwrap();
 229
 230    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 231
 232    assert_eq!(prediction.edits.len(), 1);
 233    assert_eq!(
 234        prediction.edits[0].0.to_point(&snapshot).start,
 235        language::Point::new(1, 3)
 236    );
 237    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 238}
 239
 240#[gpui::test]
 241async fn test_request_events(cx: &mut TestAppContext) {
 242    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 243    let fs = FakeFs::new(cx.executor());
 244    fs.insert_tree(
 245        "/root",
 246        json!({
 247            "foo.md": "Hello!\n\nBye\n"
 248        }),
 249    )
 250    .await;
 251    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 252
 253    let buffer = project
 254        .update(cx, |project, cx| {
 255            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 256            project.open_buffer(path, cx)
 257        })
 258        .await
 259        .unwrap();
 260
 261    ep_store.update(cx, |ep_store, cx| {
 262        ep_store.register_buffer(&buffer, &project, cx);
 263    });
 264
 265    buffer.update(cx, |buffer, cx| {
 266        buffer.edit(vec![(7..7, "How")], None, cx);
 267    });
 268
 269    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 270    let position = snapshot.anchor_before(language::Point::new(1, 3));
 271
 272    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 273        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 274    });
 275
 276    let (request, respond_tx) = requests.predict.next().await.unwrap();
 277
 278    let prompt = prompt_from_request(&request);
 279    assert!(
 280        prompt.contains(indoc! {"
 281        --- a/root/foo.md
 282        +++ b/root/foo.md
 283        @@ -1,3 +1,3 @@
 284         Hello!
 285        -
 286        +How
 287         Bye
 288    "}),
 289        "{prompt}"
 290    );
 291
 292    respond_tx
 293        .send(model_response(
 294            &request,
 295            indoc! {r#"
 296                --- a/root/foo.md
 297                +++ b/root/foo.md
 298                @@ ... @@
 299                 Hello!
 300                -How
 301                +How are you?
 302                 Bye
 303        "#},
 304        ))
 305        .unwrap();
 306
 307    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 308
 309    assert_eq!(prediction.edits.len(), 1);
 310    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 311}
 312
 313#[gpui::test]
 314async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 315    let (ep_store, _requests) = init_test_with_fake_client(cx);
 316    let fs = FakeFs::new(cx.executor());
 317    fs.insert_tree(
 318        "/root",
 319        json!({
 320            "foo.md": "Hello!\n\nBye\n"
 321        }),
 322    )
 323    .await;
 324    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 325
 326    let buffer = project
 327        .update(cx, |project, cx| {
 328            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 329            project.open_buffer(path, cx)
 330        })
 331        .await
 332        .unwrap();
 333
 334    ep_store.update(cx, |ep_store, cx| {
 335        ep_store.register_buffer(&buffer, &project, cx);
 336    });
 337
 338    // First burst: insert "How"
 339    buffer.update(cx, |buffer, cx| {
 340        buffer.edit(vec![(7..7, "How")], None, cx);
 341    });
 342
 343    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 344    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 345    cx.run_until_parked();
 346
 347    // Second burst: append " are you?" immediately after "How" on the same line.
 348    //
 349    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 350    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 351    buffer.update(cx, |buffer, cx| {
 352        buffer.edit(vec![(10..10, " are you?")], None, cx);
 353    });
 354
 355    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 356    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 357    buffer.update(cx, |buffer, cx| {
 358        buffer.edit(vec![(19..19, "!")], None, cx);
 359    });
 360
 361    // With time-based splitting, there are two distinct events.
 362    let events = ep_store.update(cx, |ep_store, cx| {
 363        ep_store.edit_history_for_project(&project, cx)
 364    });
 365    assert_eq!(events.len(), 2);
 366    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 367    assert_eq!(
 368        diff.as_str(),
 369        indoc! {"
 370            @@ -1,3 +1,3 @@
 371             Hello!
 372            -
 373            +How
 374             Bye
 375        "}
 376    );
 377
 378    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 379    assert_eq!(
 380        diff.as_str(),
 381        indoc! {"
 382            @@ -1,3 +1,3 @@
 383             Hello!
 384            -How
 385            +How are you?!
 386             Bye
 387        "}
 388    );
 389}
 390
// Despite its name, this test exercises proximity-based coalescing of manual edits in
// the edit history: an edit within 8 rows of the last event's line range (above or
// below it) extends that event, while one farther away starts a new event. Every edit
// here is manual; none is reported as predicted.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines ("Line 1".."Line 30", so row N holds "Line N+1") to
    // test line-based coalescing.
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range.
    // Row 14 is "Line 15" after the insertion above, so "Below" lands between
    // "Line 14" and "Line 15".
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event. `render_events` joins events
    // with "\n---\n", hence the blank line and `---` separator in the expectation.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 569
 570fn render_events(events: &[StoredEvent]) -> String {
 571    events
 572        .iter()
 573        .map(|e| {
 574            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 575            diff.as_str()
 576        })
 577        .collect::<Vec<_>>()
 578        .join("\n---\n")
 579}
 580
 581fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 582    events
 583        .iter()
 584        .map(|e| {
 585            let zeta_prompt::Event::BufferChange {
 586                diff, predicted, ..
 587            } = e.event.as_ref();
 588            let prefix = if *predicted { "predicted" } else { "manual" };
 589            format!("{}\n{}", prefix, diff)
 590        })
 591        .collect()
 592}
 593
// Verifies how the `predicted` flag interacts with edit-history coalescing: manual and
// predicted edits near each other are kept in separate events, same-kind edits merge,
// and once the user moves far away the earlier location collapses into one event
// whose `predicted` flag is true only if every constituent edit was predicted.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true. The edit is reported
    // through `report_changes_for_buffer` with `true`, and the expectation below shows
    // it labeled `predicted`. It is also NOT merged into the preceding manual event,
    // even though it is on the very next line.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby (the mirror image of Case 3).
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted, so the merged event below is `manual`.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
 852
 853#[gpui::test]
 854async fn test_empty_prediction(cx: &mut TestAppContext) {
 855    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 856    let fs = FakeFs::new(cx.executor());
 857    fs.insert_tree(
 858        "/root",
 859        json!({
 860            "foo.md":  "Hello!\nHow\nBye\n"
 861        }),
 862    )
 863    .await;
 864    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 865
 866    let buffer = project
 867        .update(cx, |project, cx| {
 868            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 869            project.open_buffer(path, cx)
 870        })
 871        .await
 872        .unwrap();
 873    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 874    let position = snapshot.anchor_before(language::Point::new(1, 3));
 875
 876    ep_store.update(cx, |ep_store, cx| {
 877        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 878    });
 879
 880    let (request, respond_tx) = requests.predict.next().await.unwrap();
 881    let response = model_response(&request, "");
 882    let id = response.request_id.clone();
 883    respond_tx.send(response).unwrap();
 884
 885    cx.run_until_parked();
 886
 887    ep_store.update(cx, |ep_store, cx| {
 888        assert!(
 889            ep_store
 890                .prediction_at(&buffer, None, &project, cx)
 891                .is_none()
 892        );
 893    });
 894
 895    // prediction is reported as rejected
 896    let (reject_request, _) = requests.reject.next().await.unwrap();
 897
 898    assert_eq!(
 899        &reject_request.rejections,
 900        &[EditPredictionRejection {
 901            request_id: id,
 902            reason: EditPredictionRejectReason::Empty,
 903            was_shown: false,
 904            model_version: None,
 905        }]
 906    );
 907}
 908
/// Verifies that a prediction whose edit the user has already typed before the
/// response arrives is not surfaced, and is reported as rejected with
/// `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on the second line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Before the model responds, the user edits the buffer so it already
    // contains "How are you?", which is exactly what SIMPLE_DIFF would insert.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // Nothing is left to suggest once the response is interpolated onto the
    // edited buffer.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
 969
/// Unified diff shared by several tests: in `root/foo.md` it replaces the line
/// `How` with `How are you?`. Tests pass it to `model_response`, which applies
/// it to the request's editable excerpt to build a fake model output.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 979
/// Verifies that a newer prediction for the same position replaces the current
/// one, and that the replaced prediction is reported as rejected with
/// `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on the second line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // The second response carries the same diff as the first, under a new id.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1062
/// Verifies that when a newer response is judged worse than the current
/// prediction, the current prediction is kept and the newer one is reported as
/// rejected with `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on the second line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response ("How are you?") becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction: it only suggests "How are", and the
    // assertions below expect the store to keep the first prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1157
1158#[gpui::test]
1159async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1160    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1161    let fs = FakeFs::new(cx.executor());
1162    fs.insert_tree(
1163        "/root",
1164        json!({
1165            "foo.md":  "Hello!\nHow\nBye\n"
1166        }),
1167    )
1168    .await;
1169    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1170
1171    let buffer = project
1172        .update(cx, |project, cx| {
1173            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1174            project.open_buffer(path, cx)
1175        })
1176        .await
1177        .unwrap();
1178    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1179    let position = snapshot.anchor_before(language::Point::new(1, 3));
1180
1181    // start two refresh tasks
1182    ep_store.update(cx, |ep_store, cx| {
1183        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1184    });
1185
1186    let (request1, respond_first) = requests.predict.next().await.unwrap();
1187
1188    ep_store.update(cx, |ep_store, cx| {
1189        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1190    });
1191
1192    let (request, respond_second) = requests.predict.next().await.unwrap();
1193
1194    // wait for throttle
1195    cx.run_until_parked();
1196
1197    // second responds first
1198    let second_response = model_response(&request, SIMPLE_DIFF);
1199    let second_id = second_response.request_id.clone();
1200    respond_second.send(second_response).unwrap();
1201
1202    cx.run_until_parked();
1203
1204    ep_store.update(cx, |ep_store, cx| {
1205        // current prediction is second
1206        assert_eq!(
1207            ep_store
1208                .prediction_at(&buffer, None, &project, cx)
1209                .unwrap()
1210                .id
1211                .0,
1212            second_id
1213        );
1214    });
1215
1216    let first_response = model_response(&request1, SIMPLE_DIFF);
1217    let first_id = first_response.request_id.clone();
1218    respond_first.send(first_response).unwrap();
1219
1220    cx.run_until_parked();
1221
1222    ep_store.update(cx, |ep_store, cx| {
1223        // current prediction is still second, since first was cancelled
1224        assert_eq!(
1225            ep_store
1226                .prediction_at(&buffer, None, &project, cx)
1227                .unwrap()
1228                .id
1229                .0,
1230            second_id
1231        );
1232    });
1233
1234    // first is reported as rejected
1235    let (reject_request, _) = requests.reject.next().await.unwrap();
1236
1237    cx.run_until_parked();
1238
1239    assert_eq!(
1240        &reject_request.rejections,
1241        &[EditPredictionRejection {
1242            request_id: first_id,
1243            reason: EditPredictionRejectReason::Canceled,
1244            was_shown: false,
1245            model_version: None,
1246        }]
1247    );
1248}
1249
/// Verifies that when a third refresh starts while two requests are still
/// pending, the second one is cancelled.
///
/// The first response still becomes current, the cancelled second response is
/// ignored, and the third replaces the first. The cancelled second prediction
/// and the replaced first prediction are reported together in one rejection
/// batch.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on the second line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled. `1` is presumably the second
        // pending request's id (ids appear to start at 0) — TODO confirm.
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request still receives a response; it must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both the cancelled second prediction and the replaced first prediction
    // are reported, in that order, in a single rejection batch.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1387
/// Verifies that edit-triggered and diagnostic-triggered (jump) prediction
/// requests are throttled separately.
///
/// A first request of each kind goes through immediately. A second request of
/// each kind is held back until `EditPredictionStore::THROTTLE_TIMEOUT`
/// elapses, after which both are sent.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" on the second line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Register the project and buffer so diagnostic updates reach the store.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request, triggered by a diagnostics update for `bar.md` (not
    // the open buffer). There is no prior jump, so it is not throttled, even
    // though an edit request was just made.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1488
1489#[gpui::test]
1490async fn test_rejections_flushing(cx: &mut TestAppContext) {
1491    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1492
1493    ep_store.update(cx, |ep_store, cx| {
1494        ep_store.reject_prediction(
1495            EditPredictionId("test-1".into()),
1496            EditPredictionRejectReason::Discarded,
1497            false,
1498            None,
1499            cx,
1500        );
1501        ep_store.reject_prediction(
1502            EditPredictionId("test-2".into()),
1503            EditPredictionRejectReason::Canceled,
1504            true,
1505            None,
1506            cx,
1507        );
1508    });
1509
1510    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1511    cx.run_until_parked();
1512
1513    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1514    respond_tx.send(()).unwrap();
1515
1516    // batched
1517    assert_eq!(reject_request.rejections.len(), 2);
1518    assert_eq!(
1519        reject_request.rejections[0],
1520        EditPredictionRejection {
1521            request_id: "test-1".to_string(),
1522            reason: EditPredictionRejectReason::Discarded,
1523            was_shown: false,
1524            model_version: None,
1525        }
1526    );
1527    assert_eq!(
1528        reject_request.rejections[1],
1529        EditPredictionRejection {
1530            request_id: "test-2".to_string(),
1531            reason: EditPredictionRejectReason::Canceled,
1532            was_shown: true,
1533            model_version: None,
1534        }
1535    );
1536
1537    // Reaching batch size limit sends without debounce
1538    ep_store.update(cx, |ep_store, cx| {
1539        for i in 0..70 {
1540            ep_store.reject_prediction(
1541                EditPredictionId(format!("batch-{}", i).into()),
1542                EditPredictionRejectReason::Discarded,
1543                false,
1544                None,
1545                cx,
1546            );
1547        }
1548    });
1549
1550    // First MAX/2 items are sent immediately
1551    cx.run_until_parked();
1552    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1553    respond_tx.send(()).unwrap();
1554
1555    assert_eq!(reject_request.rejections.len(), 50);
1556    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1557    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1558
1559    // Remaining items are debounced with the next batch
1560    cx.executor().advance_clock(Duration::from_secs(15));
1561    cx.run_until_parked();
1562
1563    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1564    respond_tx.send(()).unwrap();
1565
1566    assert_eq!(reject_request.rejections.len(), 20);
1567    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1568    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1569
1570    // Request failure
1571    ep_store.update(cx, |ep_store, cx| {
1572        ep_store.reject_prediction(
1573            EditPredictionId("retry-1".into()),
1574            EditPredictionRejectReason::Discarded,
1575            false,
1576            None,
1577            cx,
1578        );
1579    });
1580
1581    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1582    cx.run_until_parked();
1583
1584    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1585    assert_eq!(reject_request.rejections.len(), 1);
1586    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1587    // Simulate failure
1588    drop(_respond_tx);
1589
1590    // Add another rejection
1591    ep_store.update(cx, |ep_store, cx| {
1592        ep_store.reject_prediction(
1593            EditPredictionId("retry-2".into()),
1594            EditPredictionRejectReason::Discarded,
1595            false,
1596            None,
1597            cx,
1598        );
1599    });
1600
1601    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1602    cx.run_until_parked();
1603
1604    // Retry should include both the failed item and the new one
1605    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1606    respond_tx.send(()).unwrap();
1607
1608    assert_eq!(reject_request.rejections.len(), 2);
1609    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1610    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1611}
1612
1613// Skipped until we start including diagnostics in prompt
1614// #[gpui::test]
1615// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1616//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1617//     let fs = FakeFs::new(cx.executor());
1618//     fs.insert_tree(
1619//         "/root",
1620//         json!({
1621//             "foo.md": "Hello!\nBye"
1622//         }),
1623//     )
1624//     .await;
1625//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1626
1627//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1628//     let diagnostic = lsp::Diagnostic {
1629//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1630//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1631//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1632//         ..Default::default()
1633//     };
1634
1635//     project.update(cx, |project, cx| {
1636//         project.lsp_store().update(cx, |lsp_store, cx| {
1637//             // Create some diagnostics
1638//             lsp_store
1639//                 .update_diagnostics(
1640//                     LanguageServerId(0),
1641//                     lsp::PublishDiagnosticsParams {
1642//                         uri: path_to_buffer_uri.clone(),
1643//                         diagnostics: vec![diagnostic],
1644//                         version: None,
1645//                     },
1646//                     None,
1647//                     language::DiagnosticSourceKind::Pushed,
1648//                     &[],
1649//                     cx,
1650//                 )
1651//                 .unwrap();
1652//         });
1653//     });
1654
1655//     let buffer = project
1656//         .update(cx, |project, cx| {
1657//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1658//             project.open_buffer(path, cx)
1659//         })
1660//         .await
1661//         .unwrap();
1662
1663//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1664//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1665
1666//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1667//         ep_store.request_prediction(&project, &buffer, position, cx)
1668//     });
1669
1670//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1671
1672//     assert_eq!(request.diagnostic_groups.len(), 1);
1673//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1674//         .unwrap();
1675//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1676//     assert_eq!(
1677//         value,
1678//         json!({
1679//             "entries": [{
1680//                 "range": {
1681//                     "start": 8,
1682//                     "end": 10
1683//                 },
1684//                 "diagnostic": {
1685//                     "source": null,
1686//                     "code": null,
1687//                     "code_description": null,
1688//                     "severity": 1,
1689//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1690//                     "markdown": null,
1691//                     "group_id": 0,
1692//                     "is_primary": true,
1693//                     "is_disk_based": false,
1694//                     "is_unnecessary": false,
1695//                     "source_kind": "Pushed",
1696//                     "data": null,
1697//                     "underline": true
1698//                 }
1699//             }],
1700//             "primary_ix": 0
1701//         })
1702//     );
1703// }
1704
1705// Generate a model response that would apply the given diff to the active file.
1706fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1707    let editable_range =
1708        zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
1709    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1710    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1711
1712    PredictEditsV3Response {
1713        request_id: Uuid::new_v4().to_string(),
1714        editable_range,
1715        output: new_excerpt,
1716        model_version: None,
1717    }
1718}
1719
1720fn empty_response() -> PredictEditsV3Response {
1721    PredictEditsV3Response {
1722        request_id: Uuid::new_v4().to_string(),
1723        editable_range: 0..0,
1724        output: String::new(),
1725        model_version: None,
1726    }
1727}
1728
1729fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1730    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1731}
1732
1733fn assert_no_predict_request_ready(
1734    requests: &mut mpsc::UnboundedReceiver<(
1735        PredictEditsV3Request,
1736        oneshot::Sender<PredictEditsV3Response>,
1737    )>,
1738) {
1739    if requests.next().now_or_never().flatten().is_some() {
1740        panic!("Unexpected prediction request while throttled.");
1741    }
1742}
1743
/// Receivers for the requests captured by the fake HTTP client set up in
/// `init_test_with_fake_client`.
///
/// Each captured request is paired with a oneshot sender. The test uses it to
/// deliver the response. Dropping the sender without sending makes the
/// corresponding HTTP request fail.
struct RequestChannels {
    /// Decoded `/predict_edits/v3` requests, each with the sender for its model response.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Decoded `/predict_edits/reject` bodies, each with the sender that acknowledges it.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1751
1752fn init_test_with_fake_client(
1753    cx: &mut TestAppContext,
1754) -> (Entity<EditPredictionStore>, RequestChannels) {
1755    cx.update(move |cx| {
1756        let settings_store = SettingsStore::test(cx);
1757        cx.set_global(settings_store);
1758        zlog::init_test();
1759
1760        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1761        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1762
1763        let http_client = FakeHttpClient::create({
1764            move |req| {
1765                let uri = req.uri().path().to_string();
1766                let mut body = req.into_body();
1767                let predict_req_tx = predict_req_tx.clone();
1768                let reject_req_tx = reject_req_tx.clone();
1769                async move {
1770                    let resp = match uri.as_str() {
1771                        "/client/llm_tokens" => serde_json::to_string(&json!({
1772                            "token": "test"
1773                        }))
1774                        .unwrap(),
1775                        "/predict_edits/v3" => {
1776                            let mut buf = Vec::new();
1777                            body.read_to_end(&mut buf).await.ok();
1778                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
1779                            let req = serde_json::from_slice(&decompressed).unwrap();
1780
1781                            let (res_tx, res_rx) = oneshot::channel();
1782                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1783                            serde_json::to_string(&res_rx.await?).unwrap()
1784                        }
1785                        "/predict_edits/reject" => {
1786                            let mut buf = Vec::new();
1787                            body.read_to_end(&mut buf).await.ok();
1788                            let req = serde_json::from_slice(&buf).unwrap();
1789
1790                            let (res_tx, res_rx) = oneshot::channel();
1791                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1792                            serde_json::to_string(&res_rx.await?).unwrap()
1793                        }
1794                        _ => {
1795                            panic!("Unexpected path: {}", uri)
1796                        }
1797                    };
1798
1799                    Ok(Response::builder().body(resp.into()).unwrap())
1800                }
1801            }
1802        });
1803
1804        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1805        client.cloud_client().set_credentials(1, "test".into());
1806
1807        language_model::init(client.clone(), cx);
1808
1809        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1810        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1811
1812        (
1813            ep_store,
1814            RequestChannels {
1815                predict: predict_req_rx,
1816                reject: reject_req_rx,
1817            },
1818        )
1819    })
1820}
1821
1822#[gpui::test]
1823async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1824    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1825    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1826        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1827    });
1828
1829    let edit_preview = cx
1830        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1831        .await;
1832
1833    let prediction = EditPrediction {
1834        edits,
1835        cursor_position: None,
1836        edit_preview,
1837        buffer: buffer.clone(),
1838        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1839        id: EditPredictionId("the-id".into()),
1840        inputs: ZetaPromptInput {
1841            events: Default::default(),
1842            related_files: Default::default(),
1843            cursor_path: Path::new("").into(),
1844            cursor_excerpt: "".into(),
1845            cursor_offset_in_excerpt: 0,
1846            excerpt_start_row: None,
1847            excerpt_ranges: Default::default(),
1848            experiment: None,
1849            in_open_source_repo: false,
1850            can_collect_data: false,
1851            repo_url: None,
1852        },
1853        buffer_snapshotted_at: Instant::now(),
1854        response_received_at: Instant::now(),
1855        model_version: None,
1856    };
1857
1858    cx.update(|cx| {
1859        assert_eq!(
1860            from_completion_edits(
1861                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1862                &buffer,
1863                cx
1864            ),
1865            vec![(2..5, "REM".into()), (9..11, "".into())]
1866        );
1867
1868        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1869        assert_eq!(
1870            from_completion_edits(
1871                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1872                &buffer,
1873                cx
1874            ),
1875            vec![(2..2, "REM".into()), (6..8, "".into())]
1876        );
1877
1878        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1879        assert_eq!(
1880            from_completion_edits(
1881                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1882                &buffer,
1883                cx
1884            ),
1885            vec![(2..5, "REM".into()), (9..11, "".into())]
1886        );
1887
1888        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1889        assert_eq!(
1890            from_completion_edits(
1891                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1892                &buffer,
1893                cx
1894            ),
1895            vec![(3..3, "EM".into()), (7..9, "".into())]
1896        );
1897
1898        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1899        assert_eq!(
1900            from_completion_edits(
1901                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1902                &buffer,
1903                cx
1904            ),
1905            vec![(4..4, "M".into()), (8..10, "".into())]
1906        );
1907
1908        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1909        assert_eq!(
1910            from_completion_edits(
1911                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1912                &buffer,
1913                cx
1914            ),
1915            vec![(9..11, "".into())]
1916        );
1917
1918        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1919        assert_eq!(
1920            from_completion_edits(
1921                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1922                &buffer,
1923                cx
1924            ),
1925            vec![(4..4, "M".into()), (8..10, "".into())]
1926        );
1927
1928        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1929        assert_eq!(
1930            from_completion_edits(
1931                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1932                &buffer,
1933                cx
1934            ),
1935            vec![(4..4, "M".into())]
1936        );
1937
1938        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1939        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1940    })
1941}
1942
1943#[gpui::test]
1944async fn test_clean_up_diff(cx: &mut TestAppContext) {
1945    init_test(cx);
1946
1947    assert_eq!(
1948        apply_edit_prediction(
1949            indoc! {"
1950                    fn main() {
1951                        let word_1 = \"lorem\";
1952                        let range = word.len()..word.len();
1953                    }
1954                "},
1955            indoc! {"
1956                    fn main() {
1957                        let word_1 = \"lorem\";
1958                        let range = word_1.len()..word_1.len();
1959                    }
1960                "},
1961            cx,
1962        )
1963        .await,
1964        indoc! {"
1965                fn main() {
1966                    let word_1 = \"lorem\";
1967                    let range = word_1.len()..word_1.len();
1968                }
1969            "},
1970    );
1971
1972    assert_eq!(
1973        apply_edit_prediction(
1974            indoc! {"
1975                    fn main() {
1976                        let story = \"the quick\"
1977                    }
1978                "},
1979            indoc! {"
1980                    fn main() {
1981                        let story = \"the quick brown fox jumps over the lazy dog\";
1982                    }
1983                "},
1984            cx,
1985        )
1986        .await,
1987        indoc! {"
1988                fn main() {
1989                    let story = \"the quick brown fox jumps over the lazy dog\";
1990                }
1991            "},
1992    );
1993}
1994
1995#[gpui::test]
1996async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1997    init_test(cx);
1998
1999    let buffer_content = "lorem\n";
2000    let completion_response = "lorem\nipsum\n";
2001
2002    assert_eq!(
2003        apply_edit_prediction(buffer_content, completion_response, cx).await,
2004        "lorem\nipsum\n"
2005    );
2006}
2007
2008#[gpui::test]
2009async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
2010    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
2011    // When the buffer ends without a trailing newline, but the model returns output
2012    // with a trailing newline, zeta2 should normalize both sides before diffing
2013    // so no spurious newline is inserted.
2014    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2015    let fs = FakeFs::new(cx.executor());
2016
2017    // Single line buffer with no trailing newline
2018    fs.insert_tree(
2019        "/root",
2020        json!({
2021            "foo.txt": "hello"
2022        }),
2023    )
2024    .await;
2025    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2026
2027    let buffer = project
2028        .update(cx, |project, cx| {
2029            let path = project
2030                .find_project_path(path!("root/foo.txt"), cx)
2031                .unwrap();
2032            project.open_buffer(path, cx)
2033        })
2034        .await
2035        .unwrap();
2036
2037    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2038    let position = snapshot.anchor_before(language::Point::new(0, 5));
2039
2040    ep_store.update(cx, |ep_store, cx| {
2041        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2042    });
2043
2044    let (request, respond_tx) = requests.predict.next().await.unwrap();
2045
2046    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2047    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2048    let excerpt_length = request.input.cursor_excerpt.len();
2049    let response = PredictEditsV3Response {
2050        request_id: Uuid::new_v4().to_string(),
2051        output: "hello world\n".to_string(),
2052        editable_range: 0..excerpt_length,
2053        model_version: None,
2054    };
2055    respond_tx.send(response).unwrap();
2056
2057    cx.run_until_parked();
2058
2059    // The prediction should insert " world" without adding a newline
2060    ep_store.update(cx, |ep_store, cx| {
2061        let prediction = ep_store
2062            .prediction_at(&buffer, None, &project, cx)
2063            .expect("should have prediction");
2064        let edits: Vec<_> = prediction
2065            .edits
2066            .iter()
2067            .map(|(range, text)| {
2068                let snapshot = buffer.read(cx).snapshot();
2069                (range.to_offset(&snapshot), text.clone())
2070            })
2071            .collect();
2072        assert_eq!(edits, vec![(5..5, " world".into())]);
2073    });
2074}
2075
2076fn init_test(cx: &mut TestAppContext) {
2077    cx.update(|cx| {
2078        let settings_store = SettingsStore::test(cx);
2079        cx.set_global(settings_store);
2080    });
2081}
2082
2083async fn apply_edit_prediction(
2084    buffer_content: &str,
2085    completion_response: &str,
2086    cx: &mut TestAppContext,
2087) -> String {
2088    let fs = project::FakeFs::new(cx.executor());
2089    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2090    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2091    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2092    *response.lock() = completion_response.to_string();
2093    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2094    buffer.update(cx, |buffer, cx| {
2095        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2096    });
2097    buffer.read_with(cx, |buffer, _| buffer.text())
2098}
2099
2100async fn run_edit_prediction(
2101    buffer: &Entity<Buffer>,
2102    project: &Entity<Project>,
2103    ep_store: &Entity<EditPredictionStore>,
2104    cx: &mut TestAppContext,
2105) -> EditPrediction {
2106    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2107    ep_store.update(cx, |ep_store, cx| {
2108        ep_store.register_buffer(buffer, &project, cx)
2109    });
2110    cx.background_executor.run_until_parked();
2111    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2112        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2113    });
2114    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2115}
2116
2117async fn make_test_ep_store(
2118    project: &Entity<Project>,
2119    cx: &mut TestAppContext,
2120) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2121    let default_response = "hello world\n".to_string();
2122    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2123    let http_client = FakeHttpClient::create({
2124        let completion_response = completion_response.clone();
2125        let mut next_request_id = 0;
2126        move |req| {
2127            let completion_response = completion_response.clone();
2128            let method = req.method().clone();
2129            let uri = req.uri().path().to_string();
2130            let mut body = req.into_body();
2131            async move {
2132                match (method, uri.as_str()) {
2133                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2134                        .status(200)
2135                        .body(
2136                            serde_json::to_string(&CreateLlmTokenResponse {
2137                                token: LlmToken("the-llm-token".to_string()),
2138                            })
2139                            .unwrap()
2140                            .into(),
2141                        )
2142                        .unwrap()),
2143                    (Method::POST, "/predict_edits/v3") => {
2144                        let mut buf = Vec::new();
2145                        body.read_to_end(&mut buf).await.ok();
2146                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
2147                        let req: PredictEditsV3Request =
2148                            serde_json::from_slice(&decompressed).unwrap();
2149
2150                        next_request_id += 1;
2151                        Ok(http_client::Response::builder()
2152                            .status(200)
2153                            .body(
2154                                serde_json::to_string(&PredictEditsV3Response {
2155                                    request_id: format!("request-{next_request_id}"),
2156                                    editable_range: 0..req.input.cursor_excerpt.len(),
2157                                    output: completion_response.lock().clone(),
2158                                    model_version: None,
2159                                })
2160                                .unwrap()
2161                                .into(),
2162                            )
2163                            .unwrap())
2164                    }
2165                    _ => Ok(http_client::Response::builder()
2166                        .status(404)
2167                        .body("Not Found".to_string().into())
2168                        .unwrap()),
2169                }
2170            }
2171        }
2172    });
2173
2174    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2175    cx.update(|cx| {
2176        RefreshLlmTokenListener::register(client.clone(), cx);
2177    });
2178    let _server = FakeServer::for_client(42, &client, cx).await;
2179
2180    let ep_store = cx.new(|cx| {
2181        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2182        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2183
2184        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2185        for worktree in worktrees {
2186            let worktree_id = worktree.read(cx).id();
2187            ep_store
2188                .get_or_init_project(project, cx)
2189                .license_detection_watchers
2190                .entry(worktree_id)
2191                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2192        }
2193
2194        ep_store
2195    });
2196
2197    (ep_store, completion_response)
2198}
2199
2200fn to_completion_edits(
2201    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2202    buffer: &Entity<Buffer>,
2203    cx: &App,
2204) -> Vec<(Range<Anchor>, Arc<str>)> {
2205    let buffer = buffer.read(cx);
2206    iterator
2207        .into_iter()
2208        .map(|(range, text)| {
2209            (
2210                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2211                text,
2212            )
2213        })
2214        .collect()
2215}
2216
2217fn from_completion_edits(
2218    editor_edits: &[(Range<Anchor>, Arc<str>)],
2219    buffer: &Entity<Buffer>,
2220    cx: &App,
2221) -> Vec<(Range<usize>, Arc<str>)> {
2222    let buffer = buffer.read(cx);
2223    editor_edits
2224        .iter()
2225        .map(|(range, text)| {
2226            (
2227                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2228                text.clone(),
2229            )
2230        })
2231        .collect()
2232}
2233
2234#[gpui::test]
2235async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2236    init_test(cx);
2237
2238    let fs = FakeFs::new(cx.executor());
2239    fs.insert_tree(
2240        "/project",
2241        serde_json::json!({
2242            "main.rs": "fn main() {\n    \n}\n"
2243        }),
2244    )
2245    .await;
2246
2247    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2248
2249    let http_client = FakeHttpClient::create(|_req| async move {
2250        Ok(gpui::http_client::Response::builder()
2251            .status(401)
2252            .body("Unauthorized".into())
2253            .unwrap())
2254    });
2255
2256    let client =
2257        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2258    cx.update(|cx| {
2259        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2260    });
2261
2262    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2263
2264    let buffer = project
2265        .update(cx, |project, cx| {
2266            let path = project
2267                .find_project_path(path!("/project/main.rs"), cx)
2268                .unwrap();
2269            project.open_buffer(path, cx)
2270        })
2271        .await
2272        .unwrap();
2273
2274    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2275    ep_store.update(cx, |ep_store, cx| {
2276        ep_store.register_buffer(&buffer, &project, cx)
2277    });
2278    cx.background_executor.run_until_parked();
2279
2280    let completion_task = ep_store.update(cx, |ep_store, cx| {
2281        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2282        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2283    });
2284
2285    let result = completion_task.await;
2286    assert!(
2287        result.is_err(),
2288        "Without authentication and without custom URL, prediction should fail"
2289    );
2290}
2291
2292#[gpui::test]
2293fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2294    let buffer = cx.new(|cx| {
2295        Buffer::local(
2296            indoc! {"
2297                zero
2298                one
2299                two
2300                three
2301                four
2302                five
2303                six
2304                seven
2305                eight
2306                nine
2307                ten
2308                eleven
2309                twelve
2310                thirteen
2311                fourteen
2312                fifteen
2313                sixteen
2314                seventeen
2315                eighteen
2316                nineteen
2317                twenty
2318                twenty-one
2319                twenty-two
2320                twenty-three
2321                twenty-four
2322            "},
2323            cx,
2324        )
2325    });
2326
2327    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2328
2329    buffer.update(cx, |buffer, cx| {
2330        let point = Point::new(12, 0);
2331        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2332        let point = Point::new(8, 0);
2333        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2334    });
2335
2336    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2337
2338    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2339
2340    assert_eq!(
2341        diff,
2342        indoc! {"
2343            @@ -6,10 +6,12 @@
2344             five
2345             six
2346             seven
2347            +FIRST INSERTION
2348             eight
2349             nine
2350             ten
2351             eleven
2352            +SECOND INSERTION
2353             twelve
2354             thirteen
2355             fourteen
2356            "}
2357    );
2358}
2359
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Verifies that `EditPredictionStore::next_diagnostic_location` avoids
    // diagnostics where a collaborator's cursor sits:
    // - in the active file, a diagnostic near a collaborator's cursor is skipped
    //   unless the local cursor is also near it;
    // - for a cross-file result, a file containing a collaborator's cursor is
    //   skipped in favor of one without.

    /// Simulates a remote collaborator (replica 10) placing an empty selection at
    /// the start of `row` in `buffer`, by applying an `UpdateSelections` operation.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    /// Publishes one ERROR diagnostic per entry in `rows` (covering columns 0..5
    /// of that row) for the file at `uri_path`, as pushed diagnostics from
    /// language server 0.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // The active file has 60 lines of the form "line N" (rows 0..=59).
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // A collaborator's cursor sits at row 5 of the active file.
    set_collaborator_cursor(&active_buffer, 5, cx);

    // Diagnostics at row 3 (next to the collaborator), row 25 and row 50.
    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Scenario 1: local cursor at row 25, away from the collaborator.
    // NOTE(review): `empty_search_range` is `Point(0,0)..Point(0,0)`; presumably
    // it excludes nothing of interest from the same-file search — confirm
    // against `next_diagnostic_location`.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    // The result must stay in the active file and must not be row 3.
    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Scenario 2: local cursor at row 4, next to the collaborator as well, so
    // the diagnostic at row 3 becomes eligible and is expected to be picked.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Scenario 3: local cursor at row 50; the diagnostic on that row is expected.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Scenario 4 (cross-file): both other files get a diagnostic at row 0, and
    // only collab_file.txt gets a collaborator's cursor.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // NOTE(review): the search range spans rows 0..59 of the active file;
    // presumably this rules out every same-file diagnostic so the result has to
    // come from another file — confirm against `next_diagnostic_location`.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    // The chosen file must be free_file.txt, not the one with a collaborator.
    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2575
2576#[gpui::test]
2577async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
2578    let (ep_store, _requests) = init_test_with_fake_client(cx);
2579    let fs = FakeFs::new(cx.executor());
2580
2581    // Buffer with two clearly separated regions:
2582    //   Region A = lines 0-9   (offsets 0..50)
2583    //   Region B = lines 20-29 (offsets 105..155)
2584    // A big gap in between so edits in one region never overlap the other.
2585    let mut content = String::new();
2586    for i in 0..30 {
2587        content.push_str(&format!("line {i:02}\n"));
2588    }
2589
2590    fs.insert_tree(
2591        "/root",
2592        json!({
2593            "foo.md": content.clone()
2594        }),
2595    )
2596    .await;
2597    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2598
2599    let buffer = project
2600        .update(cx, |project, cx| {
2601            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2602            project.open_buffer(path, cx)
2603        })
2604        .await
2605        .unwrap();
2606
2607    let settled_events: Arc<Mutex<Vec<(EditPredictionId, String)>>> =
2608        Arc::new(Mutex::new(Vec::new()));
2609
2610    ep_store.update(cx, |ep_store, cx| {
2611        ep_store.register_buffer(&buffer, &project, cx);
2612
2613        let settled_events = settled_events.clone();
2614        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
2615            settled_events.lock().push((id, text));
2616        }));
2617    });
2618
2619    // --- Phase 1: edit in region A and enqueue prediction A ---
2620
2621    buffer.update(cx, |buffer, cx| {
2622        // Edit at the start of line 0.
2623        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
2624    });
2625    cx.run_until_parked();
2626
2627    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2628
2629    // Region A: first 10 lines of the buffer.
2630    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
2631    ep_store.update(cx, |ep_store, cx| {
2632        ep_store.enqueue_settled_prediction(
2633            EditPredictionId("prediction-a".into()),
2634            &project,
2635            &buffer,
2636            &snapshot_a,
2637            editable_region_a,
2638            cx,
2639        );
2640    });
2641
2642    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---
2643
2644    // Let the worker process the channel message before we start advancing.
2645    cx.run_until_parked();
2646
2647    let mut region_a_edit_offset = 5;
2648    for _ in 0..3 {
2649        // Edit inside region A (not at the boundary) so `last_edit_at` is
2650        // updated before the worker's next wake.
2651        buffer.update(cx, |buffer, cx| {
2652            buffer.edit(
2653                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
2654                None,
2655                cx,
2656            );
2657        });
2658        region_a_edit_offset += 1;
2659        cx.run_until_parked();
2660
2661        cx.executor()
2662            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
2663        cx.run_until_parked();
2664        assert!(
2665            settled_events.lock().is_empty(),
2666            "no settled events should fire while region A is still being edited"
2667        );
2668    }
2669
2670    // Still nothing settled.
2671    assert!(settled_events.lock().is_empty());
2672
2673    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
2674    // Advance a small amount so B's quiescence window starts later than A's,
2675    // but not so much that A settles (A's last edit was at the start of
2676    // iteration 3, and it needs a full Q to settle).
2677    cx.executor()
2678        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2679    cx.run_until_parked();
2680    assert!(settled_events.lock().is_empty());
2681
2682    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2683    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));
2684
2685    buffer.update(cx, |buffer, cx| {
2686        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
2687    });
2688    cx.run_until_parked();
2689
2690    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2691    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
2692    ep_store.update(cx, |ep_store, cx| {
2693        ep_store.enqueue_settled_prediction(
2694            EditPredictionId("prediction-b".into()),
2695            &project,
2696            &buffer,
2697            &snapshot_b2,
2698            editable_region_b,
2699            cx,
2700        );
2701    });
2702
2703    cx.run_until_parked();
2704    assert!(
2705        settled_events.lock().is_empty(),
2706        "neither prediction should have settled yet"
2707    );
2708
2709    // --- Phase 4: let enough time pass for region A to settle ---
2710    // A's last edit was at T_a (during the last loop iteration). The worker is
2711    // sleeping until T_a + Q. We advance just enough to reach that wake time
2712    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
2713    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
2714    // enqueued only Q/4 ago and stays pending.
2715    cx.executor()
2716        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2717    cx.run_until_parked();
2718
2719    {
2720        let events = settled_events.lock().clone();
2721        assert_eq!(
2722            events.len(),
2723            1,
2724            "only prediction A should have settled, got: {events:?}"
2725        );
2726        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
2727    }
2728
2729    // --- Phase 5: let more time pass for region B to settle ---
2730    // B's last edit was Q/4 before A settled. The worker rescheduled to
2731    // B's last_edit_at + Q, which is 3Q/4 from now.
2732    cx.executor()
2733        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
2734    cx.run_until_parked();
2735
2736    {
2737        let events = settled_events.lock().clone();
2738        assert_eq!(
2739            events.len(),
2740            2,
2741            "both predictions should have settled, got: {events:?}"
2742        );
2743        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
2744    }
2745}
2746
2747#[ctor::ctor]
2748fn init_logger() {
2749    zlog::init_test();
2750}