// edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
   3use client::{UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
   8    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
   9};
  10use futures::{
  11    AsyncReadExt, FutureExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::App;
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::{Anchor, Buffer, CursorShape, Operation, Point, Selection, SelectionGoal};
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::path;
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{
  33    BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
  34    EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
  35};
  36
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    // Verifies how `prediction_at` classifies the current prediction:
    // `Local` when it targets the queried buffer, `Jump` when it targets a
    // different file (here triggered by a diagnostic in that other file).
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Mark 1.txt as the active file before opening it.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Respond with an edit inside 1.txt itself, so the resulting prediction
    // should be reported as `Local` for buffer1.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction so the next one can be observed cleanly.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt causes the store to request a
    // prediction targeting that file (observed on `requests.predict` below).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction lives in another file, so it
    // is surfaced as a `Jump` pointing at root/2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 172
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    // Happy path: one prediction request whose unified-diff response is
    // interpolated into a single buffer edit at the cursor position.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff rewrites "How" -> "How are you?", which reduces to a single
    // insertion of " are you?" starting at line 1, column 3.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 239
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    // Verifies that edits made to a registered buffer are recorded as edit
    // history and included in the prediction request prompt as a diff.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" on the empty second line; this edit should show up in the
    // prompt's edit-history diff below.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 312
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    // Verifies that a pause longer than the grouping threshold splits edit
    // history into two events, even when all edits touch the same line.
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Event 0: the pre-pause burst ("How" inserted).
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 1: both post-pause edits, coalesced into one diff.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 390
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    // Exercises line-distance-based coalescing of edit-history events:
    // edits within 8 lines of an existing event's range (above or below) are
    // merged into it; edits farther away start a new event.
    //
    // NOTE(review): the test name mentions predicted edits, but the body only
    // performs manual edits — consider renaming. TODO confirm intent.
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 569
 570fn render_events(events: &[StoredEvent]) -> String {
 571    events
 572        .iter()
 573        .map(|e| {
 574            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 575            diff.as_str()
 576        })
 577        .collect::<Vec<_>>()
 578        .join("\n---\n")
 579}
 580
 581fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 582    events
 583        .iter()
 584        .map(|e| {
 585            let zeta_prompt::Event::BufferChange {
 586                diff, predicted, ..
 587            } = e.event.as_ref();
 588            let prefix = if *predicted { "predicted" } else { "manual" };
 589            format!("{}\n{}", prefix, diff)
 590        })
 591        .collect()
 592}
 593
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    // Verifies how the `predicted` flag interacts with event coalescing:
    // manual edits merge with manual edits, predicted with predicted, and a
    // manual edit never merges into a predicted event that precedes it —
    // until edits move elsewhere, at which point a settled mixed group may
    // merge with `predicted` true only if ALL constituent edits were predicted.
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true.
    // Case 5: A manual edit that follows a predicted edit is not merged with the
    // predicted edit, even if it is nearby.
    // (`report_changes_for_buffer(.., true, ..)` marks the pending change as
    // coming from an accepted prediction.)
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5 (continued): A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
 852
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    // Verifies that an empty model response yields no prediction and is
    // automatically reported as rejected with reason `Empty` (not shown).
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff; remember the request id so the rejection
    // report can be matched against it below.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false,
            model_version: None,
        }]
    );
}
 908
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    // Verifies that when the buffer is edited (while the request is in
    // flight) to already contain the predicted text, interpolating the
    // response leaves no edits, so the prediction is dropped and reported as
    // rejected with reason `InterpolatedEmpty`.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the exact change SIMPLE_DIFF would make, before the response
    // arrives, so the interpolated prediction ends up empty.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false,
            model_version: None,
        }]
    );
}
 969
// Unified diff that rewrites the middle line of the "Hello!\nHow\nBye\n"
// fixture; fed to `model_response` to fabricate a plausible model prediction.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 979
// Verifies that a newly completed prediction replaces the currently held one,
// and that the superseded prediction is reported with `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is now current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false,
            model_version: None,
        }]
    );
}
1062
// Verifies that when a new prediction is judged worse than the one currently
// held, the current one is kept and the newcomer is rejected with
// `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is now current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false,
            model_version: None,
        }]
    );
}
1157
// Verifies that when a newer in-flight request completes first, the earlier
// pending request is cancelled: its late response is ignored and it is
// reported with `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the older (already cancelled) request respond late.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false,
            model_version: None,
        }]
    );
}
1249
// Verifies the pending-request cap: starting a third request while two are
// already pending cancels the second one. The first and third complete
// normally (third replaces first); the second's late response is ignored and
// it is reported as `Canceled`, while the first is reported as `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond on the cancelled (second) request; its result must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both outcomes arrive in one batched reject request.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false,
                model_version: None,
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false,
                model_version: None,
            }
        ]
    );
}
1387
// Verifies that edit-triggered and diagnostic-("jump")-triggered prediction
// refreshes use independent throttles: a recent edit request does not
// throttle the first jump request (and vice versa), while repeat requests of
// each kind are throttled until `THROTTLE_TIMEOUT` elapses.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
1488
1489#[gpui::test]
1490async fn test_rejections_flushing(cx: &mut TestAppContext) {
1491    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1492
1493    ep_store.update(cx, |ep_store, cx| {
1494        ep_store.reject_prediction(
1495            EditPredictionId("test-1".into()),
1496            EditPredictionRejectReason::Discarded,
1497            false,
1498            None,
1499            cx,
1500        );
1501        ep_store.reject_prediction(
1502            EditPredictionId("test-2".into()),
1503            EditPredictionRejectReason::Canceled,
1504            true,
1505            None,
1506            cx,
1507        );
1508    });
1509
1510    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1511    cx.run_until_parked();
1512
1513    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1514    respond_tx.send(()).unwrap();
1515
1516    // batched
1517    assert_eq!(reject_request.rejections.len(), 2);
1518    assert_eq!(
1519        reject_request.rejections[0],
1520        EditPredictionRejection {
1521            request_id: "test-1".to_string(),
1522            reason: EditPredictionRejectReason::Discarded,
1523            was_shown: false,
1524            model_version: None,
1525        }
1526    );
1527    assert_eq!(
1528        reject_request.rejections[1],
1529        EditPredictionRejection {
1530            request_id: "test-2".to_string(),
1531            reason: EditPredictionRejectReason::Canceled,
1532            was_shown: true,
1533            model_version: None,
1534        }
1535    );
1536
1537    // Reaching batch size limit sends without debounce
1538    ep_store.update(cx, |ep_store, cx| {
1539        for i in 0..70 {
1540            ep_store.reject_prediction(
1541                EditPredictionId(format!("batch-{}", i).into()),
1542                EditPredictionRejectReason::Discarded,
1543                false,
1544                None,
1545                cx,
1546            );
1547        }
1548    });
1549
1550    // First MAX/2 items are sent immediately
1551    cx.run_until_parked();
1552    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1553    respond_tx.send(()).unwrap();
1554
1555    assert_eq!(reject_request.rejections.len(), 50);
1556    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1557    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1558
1559    // Remaining items are debounced with the next batch
1560    cx.executor().advance_clock(Duration::from_secs(15));
1561    cx.run_until_parked();
1562
1563    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1564    respond_tx.send(()).unwrap();
1565
1566    assert_eq!(reject_request.rejections.len(), 20);
1567    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1568    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1569
1570    // Request failure
1571    ep_store.update(cx, |ep_store, cx| {
1572        ep_store.reject_prediction(
1573            EditPredictionId("retry-1".into()),
1574            EditPredictionRejectReason::Discarded,
1575            false,
1576            None,
1577            cx,
1578        );
1579    });
1580
1581    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1582    cx.run_until_parked();
1583
1584    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1585    assert_eq!(reject_request.rejections.len(), 1);
1586    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1587    // Simulate failure
1588    drop(_respond_tx);
1589
1590    // Add another rejection
1591    ep_store.update(cx, |ep_store, cx| {
1592        ep_store.reject_prediction(
1593            EditPredictionId("retry-2".into()),
1594            EditPredictionRejectReason::Discarded,
1595            false,
1596            None,
1597            cx,
1598        );
1599    });
1600
1601    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1602    cx.run_until_parked();
1603
1604    // Retry should include both the failed item and the new one
1605    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1606    respond_tx.send(()).unwrap();
1607
1608    assert_eq!(reject_request.rejections.len(), 2);
1609    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1610    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1611}
1612
1613// Skipped until we start including diagnostics in prompt
1614// #[gpui::test]
1615// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1616//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1617//     let fs = FakeFs::new(cx.executor());
1618//     fs.insert_tree(
1619//         "/root",
1620//         json!({
1621//             "foo.md": "Hello!\nBye"
1622//         }),
1623//     )
1624//     .await;
1625//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1626
1627//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1628//     let diagnostic = lsp::Diagnostic {
1629//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1630//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1631//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1632//         ..Default::default()
1633//     };
1634
1635//     project.update(cx, |project, cx| {
1636//         project.lsp_store().update(cx, |lsp_store, cx| {
1637//             // Create some diagnostics
1638//             lsp_store
1639//                 .update_diagnostics(
1640//                     LanguageServerId(0),
1641//                     lsp::PublishDiagnosticsParams {
1642//                         uri: path_to_buffer_uri.clone(),
1643//                         diagnostics: vec![diagnostic],
1644//                         version: None,
1645//                     },
1646//                     None,
1647//                     language::DiagnosticSourceKind::Pushed,
1648//                     &[],
1649//                     cx,
1650//                 )
1651//                 .unwrap();
1652//         });
1653//     });
1654
1655//     let buffer = project
1656//         .update(cx, |project, cx| {
1657//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1658//             project.open_buffer(path, cx)
1659//         })
1660//         .await
1661//         .unwrap();
1662
1663//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1664//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1665
1666//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1667//         ep_store.request_prediction(&project, &buffer, position, cx)
1668//     });
1669
1670//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1671
1672//     assert_eq!(request.diagnostic_groups.len(), 1);
1673//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1674//         .unwrap();
1675//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1676//     assert_eq!(
1677//         value,
1678//         json!({
1679//             "entries": [{
1680//                 "range": {
1681//                     "start": 8,
1682//                     "end": 10
1683//                 },
1684//                 "diagnostic": {
1685//                     "source": null,
1686//                     "code": null,
1687//                     "code_description": null,
1688//                     "severity": 1,
1689//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1690//                     "markdown": null,
1691//                     "group_id": 0,
1692//                     "is_primary": true,
1693//                     "is_disk_based": false,
1694//                     "is_unnecessary": false,
1695//                     "source_kind": "Pushed",
1696//                     "data": null,
1697//                     "underline": true
1698//                 }
1699//             }],
1700//             "primary_ix": 0
1701//         })
1702//     );
1703// }
1704
1705// Generate a model response that would apply the given diff to the active file.
1706fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1707    let editable_range = request
1708        .input
1709        .excerpt_ranges
1710        .as_ref()
1711        .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1)
1712        .unwrap_or(request.input.editable_range_in_excerpt.clone());
1713    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
1714    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1715
1716    PredictEditsV3Response {
1717        request_id: Uuid::new_v4().to_string(),
1718        editable_range,
1719        output: new_excerpt,
1720        model_version: None,
1721    }
1722}
1723
1724fn empty_response() -> PredictEditsV3Response {
1725    PredictEditsV3Response {
1726        request_id: Uuid::new_v4().to_string(),
1727        editable_range: 0..0,
1728        output: String::new(),
1729        model_version: None,
1730    }
1731}
1732
1733fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1734    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1735}
1736
1737fn assert_no_predict_request_ready(
1738    requests: &mut mpsc::UnboundedReceiver<(
1739        PredictEditsV3Request,
1740        oneshot::Sender<PredictEditsV3Response>,
1741    )>,
1742) {
1743    if requests.next().now_or_never().flatten().is_some() {
1744        panic!("Unexpected prediction request while throttled.");
1745    }
1746}
1747
// Receiving ends of the channels the fake HTTP client forwards intercepted
// cloud requests to. Each item pairs a decoded request body with a oneshot
// sender the test uses to supply the response.
struct RequestChannels {
    // Requests to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // Requests to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1755
// Sets up an `EditPredictionStore` backed by a fake HTTP client that
// intercepts the cloud endpoints used by edit predictions. Intercepted
// requests are forwarded over `RequestChannels` so tests can inspect them
// and control when/what the server responds.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: answer with a static fake token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            // Prediction bodies are zstd-compressed JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            // Forward to the test and block this fake request
                            // until the test sends a response (or drops the
                            // sender, which surfaces as a request error).
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            // Rejection bodies are plain (uncompressed) JSON.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1825
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    // Verifies `EditPrediction::interpolate`: as the user edits the buffer,
    // the prediction's remaining edits are remapped onto the new snapshot,
    // shrinking edits the user has already typed and returning `None` once
    // the prediction no longer applies.
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // The prediction: replace offsets 2..5 ("rem") with "REM" and delete 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Construct a prediction by hand; the prompt-input metadata is irrelevant
    // to interpolation and is left at placeholder values.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
        model_version: None,
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation yields the original edits verbatim.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the text the prediction would replace turns the replacement
        // into a pure insertion, and shifts the later deletion left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the prediction's "R" consumes the first character of the
        // replacement text; only "EM" remains to be inserted.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E" consumes another character of the replacement.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the final "M" fully satisfies the first edit; only the
        // deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the just-typed "M" re-opens the first edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the deletion by hand leaves only the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // Deleting text beyond the edit invalidates the prediction entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1946
1947#[gpui::test]
1948async fn test_clean_up_diff(cx: &mut TestAppContext) {
1949    init_test(cx);
1950
1951    assert_eq!(
1952        apply_edit_prediction(
1953            indoc! {"
1954                    fn main() {
1955                        let word_1 = \"lorem\";
1956                        let range = word.len()..word.len();
1957                    }
1958                "},
1959            indoc! {"
1960                    fn main() {
1961                        let word_1 = \"lorem\";
1962                        let range = word_1.len()..word_1.len();
1963                    }
1964                "},
1965            cx,
1966        )
1967        .await,
1968        indoc! {"
1969                fn main() {
1970                    let word_1 = \"lorem\";
1971                    let range = word_1.len()..word_1.len();
1972                }
1973            "},
1974    );
1975
1976    assert_eq!(
1977        apply_edit_prediction(
1978            indoc! {"
1979                    fn main() {
1980                        let story = \"the quick\"
1981                    }
1982                "},
1983            indoc! {"
1984                    fn main() {
1985                        let story = \"the quick brown fox jumps over the lazy dog\";
1986                    }
1987                "},
1988            cx,
1989        )
1990        .await,
1991        indoc! {"
1992                fn main() {
1993                    let story = \"the quick brown fox jumps over the lazy dog\";
1994                }
1995            "},
1996    );
1997}
1998
1999#[gpui::test]
2000async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2001    init_test(cx);
2002
2003    let buffer_content = "lorem\n";
2004    let completion_response = "lorem\nipsum\n";
2005
2006    assert_eq!(
2007        apply_edit_prediction(buffer_content, completion_response, cx).await,
2008        "lorem\nipsum\n"
2009    );
2010}
2011
2012#[gpui::test]
2013async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
2014    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
2015    // When the buffer ends without a trailing newline, but the model returns output
2016    // with a trailing newline, zeta2 should normalize both sides before diffing
2017    // so no spurious newline is inserted.
2018    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2019    let fs = FakeFs::new(cx.executor());
2020
2021    // Single line buffer with no trailing newline
2022    fs.insert_tree(
2023        "/root",
2024        json!({
2025            "foo.txt": "hello"
2026        }),
2027    )
2028    .await;
2029    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2030
2031    let buffer = project
2032        .update(cx, |project, cx| {
2033            let path = project
2034                .find_project_path(path!("root/foo.txt"), cx)
2035                .unwrap();
2036            project.open_buffer(path, cx)
2037        })
2038        .await
2039        .unwrap();
2040
2041    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2042    let position = snapshot.anchor_before(language::Point::new(0, 5));
2043
2044    ep_store.update(cx, |ep_store, cx| {
2045        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2046    });
2047
2048    let (request, respond_tx) = requests.predict.next().await.unwrap();
2049
2050    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2051    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2052    let excerpt_length = request.input.cursor_excerpt.len();
2053    let response = PredictEditsV3Response {
2054        request_id: Uuid::new_v4().to_string(),
2055        output: "hello world\n".to_string(),
2056        editable_range: 0..excerpt_length,
2057        model_version: None,
2058    };
2059    respond_tx.send(response).unwrap();
2060
2061    cx.run_until_parked();
2062
2063    // The prediction should insert " world" without adding a newline
2064    ep_store.update(cx, |ep_store, cx| {
2065        let prediction = ep_store
2066            .prediction_at(&buffer, None, &project, cx)
2067            .expect("should have prediction");
2068        let edits: Vec<_> = prediction
2069            .edits
2070            .iter()
2071            .map(|(range, text)| {
2072                let snapshot = buffer.read(cx).snapshot();
2073                (range.to_offset(&snapshot), text.clone())
2074            })
2075            .collect();
2076        assert_eq!(edits, vec![(5..5, " world".into())]);
2077    });
2078}
2079
2080fn init_test(cx: &mut TestAppContext) {
2081    cx.update(|cx| {
2082        let settings_store = SettingsStore::test(cx);
2083        cx.set_global(settings_store);
2084    });
2085}
2086
2087async fn apply_edit_prediction(
2088    buffer_content: &str,
2089    completion_response: &str,
2090    cx: &mut TestAppContext,
2091) -> String {
2092    let fs = project::FakeFs::new(cx.executor());
2093    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2094    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2095    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2096    *response.lock() = completion_response.to_string();
2097    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2098    buffer.update(cx, |buffer, cx| {
2099        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2100    });
2101    buffer.read_with(cx, |buffer, _| buffer.text())
2102}
2103
2104async fn run_edit_prediction(
2105    buffer: &Entity<Buffer>,
2106    project: &Entity<Project>,
2107    ep_store: &Entity<EditPredictionStore>,
2108    cx: &mut TestAppContext,
2109) -> EditPrediction {
2110    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2111    ep_store.update(cx, |ep_store, cx| {
2112        ep_store.register_buffer(buffer, &project, cx)
2113    });
2114    cx.background_executor.run_until_parked();
2115    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2116        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2117    });
2118    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2119}
2120
/// Builds an `EditPredictionStore` backed by a fake HTTP server.
///
/// The fake server answers `/client/llm_tokens` with a static token and
/// `/predict_edits/v3` with whatever string is currently held in the
/// returned `Arc<Mutex<String>>`, so tests can swap the model's output
/// between requests. All other paths return 404.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
    let default_response = "hello world\n".to_string();
    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
    let http_client = FakeHttpClient::create({
        let completion_response = completion_response.clone();
        // Monotonic counter so each prediction response gets a distinct id.
        let mut next_request_id = 0;
        move |req| {
            let completion_response = completion_response.clone();
            let method = req.method().clone();
            let uri = req.uri().path().to_string();
            let mut body = req.into_body();
            async move {
                match (method, uri.as_str()) {
                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    (Method::POST, "/predict_edits/v3") => {
                        // The request body is zstd-compressed JSON; decode it
                        // so the editable range can be sized to the excerpt.
                        let mut buf = Vec::new();
                        body.read_to_end(&mut buf).await.ok();
                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
                        let req: PredictEditsV3Request =
                            serde_json::from_slice(&decompressed).unwrap();

                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                    output: completion_response.lock().clone(),
                                    model_version: None,
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".to_string().into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Pre-populate license watchers for every worktree so data-collection
        // checks don't have to wait for asynchronous detection.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, completion_response)
}
2203
2204fn to_completion_edits(
2205    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2206    buffer: &Entity<Buffer>,
2207    cx: &App,
2208) -> Vec<(Range<Anchor>, Arc<str>)> {
2209    let buffer = buffer.read(cx);
2210    iterator
2211        .into_iter()
2212        .map(|(range, text)| {
2213            (
2214                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2215                text,
2216            )
2217        })
2218        .collect()
2219}
2220
2221fn from_completion_edits(
2222    editor_edits: &[(Range<Anchor>, Arc<str>)],
2223    buffer: &Entity<Buffer>,
2224    cx: &App,
2225) -> Vec<(Range<usize>, Arc<str>)> {
2226    let buffer = buffer.read(cx);
2227    editor_edits
2228        .iter()
2229        .map(|(range, text)| {
2230            (
2231                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2232                text.clone(),
2233            )
2234        })
2235        .collect()
2236}
2237
2238#[gpui::test]
2239async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2240    init_test(cx);
2241
2242    let fs = FakeFs::new(cx.executor());
2243    fs.insert_tree(
2244        "/project",
2245        serde_json::json!({
2246            "main.rs": "fn main() {\n    \n}\n"
2247        }),
2248    )
2249    .await;
2250
2251    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2252
2253    let http_client = FakeHttpClient::create(|_req| async move {
2254        Ok(gpui::http_client::Response::builder()
2255            .status(401)
2256            .body("Unauthorized".into())
2257            .unwrap())
2258    });
2259
2260    let client =
2261        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2262    cx.update(|cx| {
2263        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2264    });
2265
2266    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2267
2268    let buffer = project
2269        .update(cx, |project, cx| {
2270            let path = project
2271                .find_project_path(path!("/project/main.rs"), cx)
2272                .unwrap();
2273            project.open_buffer(path, cx)
2274        })
2275        .await
2276        .unwrap();
2277
2278    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2279    ep_store.update(cx, |ep_store, cx| {
2280        ep_store.register_buffer(&buffer, &project, cx)
2281    });
2282    cx.background_executor.run_until_parked();
2283
2284    let completion_task = ep_store.update(cx, |ep_store, cx| {
2285        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2286        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2287    });
2288
2289    let result = completion_task.await;
2290    assert!(
2291        result.is_err(),
2292        "Without authentication and without custom URL, prediction should fail"
2293    );
2294}
2295
2296#[gpui::test]
2297fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2298    let buffer = cx.new(|cx| {
2299        Buffer::local(
2300            indoc! {"
2301                zero
2302                one
2303                two
2304                three
2305                four
2306                five
2307                six
2308                seven
2309                eight
2310                nine
2311                ten
2312                eleven
2313                twelve
2314                thirteen
2315                fourteen
2316                fifteen
2317                sixteen
2318                seventeen
2319                eighteen
2320                nineteen
2321                twenty
2322                twenty-one
2323                twenty-two
2324                twenty-three
2325                twenty-four
2326            "},
2327            cx,
2328        )
2329    });
2330
2331    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2332
2333    buffer.update(cx, |buffer, cx| {
2334        let point = Point::new(12, 0);
2335        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2336        let point = Point::new(8, 0);
2337        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2338    });
2339
2340    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2341
2342    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2343
2344    assert_eq!(
2345        diff,
2346        indoc! {"
2347            @@ -6,10 +6,12 @@
2348             five
2349             six
2350             seven
2351            +FIRST INSERTION
2352             eight
2353             nine
2354             ten
2355             eleven
2356            +SECOND INSERTION
2357             twelve
2358             thirteen
2359             fourteen
2360            "}
2361    );
2362}
2363
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Simulates a remote collaborator's cursor at `row` by applying an
    // `UpdateSelections` op from a non-local replica (id 10).
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Pushes one ERROR diagnostic (columns 0..5) per row in `rows` for the
    // file at `uri_path`, as if an LSP server published them.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // A 60-line buffer so diagnostics can be placed near or far from cursors.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator sits at row 5; diagnostics at rows 3, 25, and 50.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25 — row 3 is near the collaborator but
    // far from us, so it must be excluded.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, right next to the collaborator — row 3
    // is near both, so it stays eligible.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50 — the diagnostic there is far from the
    // collaborator and adjacent to us.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4: cross-file jump. Both other files have a diagnostic at row 0,
    // but collab_file.txt also has a collaborator cursor there.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // Search range covering the whole active buffer, so no same-file
    // diagnostic is left to pick and the jump must go cross-file.
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
2579
2580#[gpui::test]
2581async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
2582    let (ep_store, _requests) = init_test_with_fake_client(cx);
2583    let fs = FakeFs::new(cx.executor());
2584
2585    // Buffer with two clearly separated regions:
2586    //   Region A = lines 0-9   (offsets 0..50)
2587    //   Region B = lines 20-29 (offsets 105..155)
2588    // A big gap in between so edits in one region never overlap the other.
2589    let mut content = String::new();
2590    for i in 0..30 {
2591        content.push_str(&format!("line {i:02}\n"));
2592    }
2593
2594    fs.insert_tree(
2595        "/root",
2596        json!({
2597            "foo.md": content.clone()
2598        }),
2599    )
2600    .await;
2601    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2602
2603    let buffer = project
2604        .update(cx, |project, cx| {
2605            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2606            project.open_buffer(path, cx)
2607        })
2608        .await
2609        .unwrap();
2610
2611    let settled_events: Arc<Mutex<Vec<(EditPredictionId, String)>>> =
2612        Arc::new(Mutex::new(Vec::new()));
2613
2614    ep_store.update(cx, |ep_store, cx| {
2615        ep_store.register_buffer(&buffer, &project, cx);
2616
2617        let settled_events = settled_events.clone();
2618        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
2619            settled_events.lock().push((id, text));
2620        }));
2621    });
2622
2623    // --- Phase 1: edit in region A and enqueue prediction A ---
2624
2625    buffer.update(cx, |buffer, cx| {
2626        // Edit at the start of line 0.
2627        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
2628    });
2629    cx.run_until_parked();
2630
2631    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2632
2633    // Region A: first 10 lines of the buffer.
2634    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
2635    ep_store.update(cx, |ep_store, cx| {
2636        ep_store.enqueue_settled_prediction(
2637            EditPredictionId("prediction-a".into()),
2638            &project,
2639            &buffer,
2640            &snapshot_a,
2641            editable_region_a,
2642            cx,
2643        );
2644    });
2645
2646    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---
2647
2648    // Let the worker process the channel message before we start advancing.
2649    cx.run_until_parked();
2650
2651    let mut region_a_edit_offset = 5;
2652    for _ in 0..3 {
2653        // Edit inside region A (not at the boundary) so `last_edit_at` is
2654        // updated before the worker's next wake.
2655        buffer.update(cx, |buffer, cx| {
2656            buffer.edit(
2657                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
2658                None,
2659                cx,
2660            );
2661        });
2662        region_a_edit_offset += 1;
2663        cx.run_until_parked();
2664
2665        cx.executor()
2666            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
2667        cx.run_until_parked();
2668        assert!(
2669            settled_events.lock().is_empty(),
2670            "no settled events should fire while region A is still being edited"
2671        );
2672    }
2673
2674    // Still nothing settled.
2675    assert!(settled_events.lock().is_empty());
2676
2677    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
2678    // Advance a small amount so B's quiescence window starts later than A's,
2679    // but not so much that A settles (A's last edit was at the start of
2680    // iteration 3, and it needs a full Q to settle).
2681    cx.executor()
2682        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2683    cx.run_until_parked();
2684    assert!(settled_events.lock().is_empty());
2685
2686    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2687    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));
2688
2689    buffer.update(cx, |buffer, cx| {
2690        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
2691    });
2692    cx.run_until_parked();
2693
2694    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2695    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
2696    ep_store.update(cx, |ep_store, cx| {
2697        ep_store.enqueue_settled_prediction(
2698            EditPredictionId("prediction-b".into()),
2699            &project,
2700            &buffer,
2701            &snapshot_b2,
2702            editable_region_b,
2703            cx,
2704        );
2705    });
2706
2707    cx.run_until_parked();
2708    assert!(
2709        settled_events.lock().is_empty(),
2710        "neither prediction should have settled yet"
2711    );
2712
2713    // --- Phase 4: let enough time pass for region A to settle ---
2714    // A's last edit was at T_a (during the last loop iteration). The worker is
2715    // sleeping until T_a + Q. We advance just enough to reach that wake time
2716    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
2717    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
2718    // enqueued only Q/4 ago and stays pending.
2719    cx.executor()
2720        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
2721    cx.run_until_parked();
2722
2723    {
2724        let events = settled_events.lock().clone();
2725        assert_eq!(
2726            events.len(),
2727            1,
2728            "only prediction A should have settled, got: {events:?}"
2729        );
2730        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
2731    }
2732
2733    // --- Phase 5: let more time pass for region B to settle ---
2734    // B's last edit was Q/4 before A settled. The worker rescheduled to
2735    // B's last_edit_at + Q, which is 3Q/4 from now.
2736    cx.executor()
2737        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
2738    cx.run_until_parked();
2739
2740    {
2741        let events = settled_events.lock().clone();
2742        assert_eq!(
2743            events.len(),
2744            2,
2745            "both predictions should have settled, got: {events:?}"
2746        );
2747        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
2748    }
2749}
2750
// Runs at program load (before the test harness executes any test) via the
// `ctor` crate, so logging is configured for every test in this file without
// each test having to call `zlog::init_test()` itself.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}