edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  10};
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::App;
  16use gpui::{
  17    Entity, TestAppContext,
  18    http_client::{FakeHttpClient, Response},
  19};
  20use indoc::indoc;
  21use language::{Buffer, Point};
  22use lsp::LanguageServerId;
  23use parking_lot::Mutex;
  24use pretty_assertions::{assert_eq, assert_matches};
  25use project::{FakeFs, Project};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, sync::Arc, time::Duration};
  29use util::{path, rel_path::rel_path};
  30use uuid::Uuid;
  31use zeta_prompt::ZetaPromptInput;
  32
  33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  34
  35#[gpui::test]
  36async fn test_current_state(cx: &mut TestAppContext) {
  37    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  38    let fs = FakeFs::new(cx.executor());
  39    fs.insert_tree(
  40        "/root",
  41        json!({
  42            "1.txt": "Hello!\nHow\nBye\n",
  43            "2.txt": "Hola!\nComo\nAdios\n"
  44        }),
  45    )
  46    .await;
  47    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  48
  49    let buffer1 = project
  50        .update(cx, |project, cx| {
  51            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  52            project.set_active_path(Some(path.clone()), cx);
  53            project.open_buffer(path, cx)
  54        })
  55        .await
  56        .unwrap();
  57    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  58    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  59
  60    ep_store.update(cx, |ep_store, cx| {
  61        ep_store.register_project(&project, cx);
  62        ep_store.register_buffer(&buffer1, &project, cx);
  63    });
  64
  65    // Prediction for current file
  66
  67    ep_store.update(cx, |ep_store, cx| {
  68        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  69    });
  70    let (request, respond_tx) = requests.predict.next().await.unwrap();
  71
  72    respond_tx
  73        .send(model_response(
  74            &request,
  75            indoc! {r"
  76                --- a/root/1.txt
  77                +++ b/root/1.txt
  78                @@ ... @@
  79                 Hello!
  80                -How
  81                +How are you?
  82                 Bye
  83            "},
  84        ))
  85        .unwrap();
  86
  87    cx.run_until_parked();
  88
  89    ep_store.update(cx, |ep_store, cx| {
  90        let prediction = ep_store
  91            .prediction_at(&buffer1, None, &project, cx)
  92            .unwrap();
  93        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  94    });
  95
  96    ep_store.update(cx, |ep_store, cx| {
  97        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
  98    });
  99
 100    // Prediction for diagnostic in another file
 101
 102    let diagnostic = lsp::Diagnostic {
 103        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 104        severity: Some(lsp::DiagnosticSeverity::ERROR),
 105        message: "Sentence is incomplete".to_string(),
 106        ..Default::default()
 107    };
 108
 109    project.update(cx, |project, cx| {
 110        project.lsp_store().update(cx, |lsp_store, cx| {
 111            lsp_store
 112                .update_diagnostics(
 113                    LanguageServerId(0),
 114                    lsp::PublishDiagnosticsParams {
 115                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 116                        diagnostics: vec![diagnostic],
 117                        version: None,
 118                    },
 119                    None,
 120                    language::DiagnosticSourceKind::Pushed,
 121                    &[],
 122                    cx,
 123                )
 124                .unwrap();
 125        });
 126    });
 127
 128    let (request, respond_tx) = requests.predict.next().await.unwrap();
 129    respond_tx
 130        .send(model_response(
 131            &request,
 132            indoc! {r#"
 133                --- a/root/2.txt
 134                +++ b/root/2.txt
 135                @@ ... @@
 136                 Hola!
 137                -Como
 138                +Como estas?
 139                 Adios
 140            "#},
 141        ))
 142        .unwrap();
 143    cx.run_until_parked();
 144
 145    ep_store.update(cx, |ep_store, cx| {
 146        let prediction = ep_store
 147            .prediction_at(&buffer1, None, &project, cx)
 148            .unwrap();
 149        assert_matches!(
 150            prediction,
 151            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 152        );
 153    });
 154
 155    let buffer2 = project
 156        .update(cx, |project, cx| {
 157            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 158            project.open_buffer(path, cx)
 159        })
 160        .await
 161        .unwrap();
 162
 163    ep_store.update(cx, |ep_store, cx| {
 164        let prediction = ep_store
 165            .prediction_at(&buffer2, None, &project, cx)
 166            .unwrap();
 167        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 168    });
 169}
 170
 171#[gpui::test]
 172async fn test_simple_request(cx: &mut TestAppContext) {
 173    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 174    let fs = FakeFs::new(cx.executor());
 175    fs.insert_tree(
 176        "/root",
 177        json!({
 178            "foo.md":  "Hello!\nHow\nBye\n"
 179        }),
 180    )
 181    .await;
 182    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 183
 184    let buffer = project
 185        .update(cx, |project, cx| {
 186            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 187            project.open_buffer(path, cx)
 188        })
 189        .await
 190        .unwrap();
 191    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 192    let position = snapshot.anchor_before(language::Point::new(1, 3));
 193
 194    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 195        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 196    });
 197
 198    let (request, respond_tx) = requests.predict.next().await.unwrap();
 199
 200    // TODO Put back when we have a structured request again
 201    // assert_eq!(
 202    //     request.excerpt_path.as_ref(),
 203    //     Path::new(path!("root/foo.md"))
 204    // );
 205    // assert_eq!(
 206    //     request.cursor_point,
 207    //     Point {
 208    //         line: Line(1),
 209    //         column: 3
 210    //     }
 211    // );
 212
 213    respond_tx
 214        .send(model_response(
 215            &request,
 216            indoc! { r"
 217                --- a/root/foo.md
 218                +++ b/root/foo.md
 219                @@ ... @@
 220                 Hello!
 221                -How
 222                +How are you?
 223                 Bye
 224            "},
 225        ))
 226        .unwrap();
 227
 228    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 229
 230    assert_eq!(prediction.edits.len(), 1);
 231    assert_eq!(
 232        prediction.edits[0].0.to_point(&snapshot).start,
 233        language::Point::new(1, 3)
 234    );
 235    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 236}
 237
 238#[gpui::test]
 239async fn test_request_events(cx: &mut TestAppContext) {
 240    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 241    let fs = FakeFs::new(cx.executor());
 242    fs.insert_tree(
 243        "/root",
 244        json!({
 245            "foo.md": "Hello!\n\nBye\n"
 246        }),
 247    )
 248    .await;
 249    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 250
 251    let buffer = project
 252        .update(cx, |project, cx| {
 253            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 254            project.open_buffer(path, cx)
 255        })
 256        .await
 257        .unwrap();
 258
 259    ep_store.update(cx, |ep_store, cx| {
 260        ep_store.register_buffer(&buffer, &project, cx);
 261    });
 262
 263    buffer.update(cx, |buffer, cx| {
 264        buffer.edit(vec![(7..7, "How")], None, cx);
 265    });
 266
 267    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 268    let position = snapshot.anchor_before(language::Point::new(1, 3));
 269
 270    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 271        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 272    });
 273
 274    let (request, respond_tx) = requests.predict.next().await.unwrap();
 275
 276    let prompt = prompt_from_request(&request);
 277    assert!(
 278        prompt.contains(indoc! {"
 279        --- a/root/foo.md
 280        +++ b/root/foo.md
 281        @@ -1,3 +1,3 @@
 282         Hello!
 283        -
 284        +How
 285         Bye
 286    "}),
 287        "{prompt}"
 288    );
 289
 290    respond_tx
 291        .send(model_response(
 292            &request,
 293            indoc! {r#"
 294                --- a/root/foo.md
 295                +++ b/root/foo.md
 296                @@ ... @@
 297                 Hello!
 298                -How
 299                +How are you?
 300                 Bye
 301        "#},
 302        ))
 303        .unwrap();
 304
 305    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 306
 307    assert_eq!(prediction.edits.len(), 1);
 308    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 309}
 310
 311#[gpui::test]
 312async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 313    let (ep_store, _requests) = init_test_with_fake_client(cx);
 314    let fs = FakeFs::new(cx.executor());
 315    fs.insert_tree(
 316        "/root",
 317        json!({
 318            "foo.md": "Hello!\n\nBye\n"
 319        }),
 320    )
 321    .await;
 322    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 323
 324    let buffer = project
 325        .update(cx, |project, cx| {
 326            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 327            project.open_buffer(path, cx)
 328        })
 329        .await
 330        .unwrap();
 331
 332    ep_store.update(cx, |ep_store, cx| {
 333        ep_store.register_buffer(&buffer, &project, cx);
 334    });
 335
 336    // First burst: insert "How"
 337    buffer.update(cx, |buffer, cx| {
 338        buffer.edit(vec![(7..7, "How")], None, cx);
 339    });
 340
 341    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 342    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 343    cx.run_until_parked();
 344
 345    // Second burst: append " are you?" immediately after "How" on the same line.
 346    //
 347    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 348    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 349    buffer.update(cx, |buffer, cx| {
 350        buffer.edit(vec![(10..10, " are you?")], None, cx);
 351    });
 352
 353    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 354    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 355    buffer.update(cx, |buffer, cx| {
 356        buffer.edit(vec![(19..19, "!")], None, cx);
 357    });
 358
 359    // With time-based splitting, there are two distinct events.
 360    let events = ep_store.update(cx, |ep_store, cx| {
 361        ep_store.edit_history_for_project(&project, cx)
 362    });
 363    assert_eq!(events.len(), 2);
 364    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 365    assert_eq!(
 366        diff.as_str(),
 367        indoc! {"
 368            @@ -1,3 +1,3 @@
 369             Hello!
 370            -
 371            +How
 372             Bye
 373        "}
 374    );
 375
 376    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 377    assert_eq!(
 378        diff.as_str(),
 379        indoc! {"
 380            @@ -1,3 +1,3 @@
 381             Hello!
 382            -How
 383            +How are you?!
 384             Bye
 385        "}
 386    );
 387}
 388
/// Checks line-distance coalescing in the edit history.
///
/// Manual edits made close above or close below the range of the current event extend
/// that event. An edit far below the range starts a new event; the rendered events are
/// joined by `---`.
// NOTE(review): despite its name, this test never reports an edit as predicted (it never
// calls `report_changes_for_buffer(.., true, ..)`). A name such as
// `test_edit_history_line_based_coalescing` would describe it better; consider renaming.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines ("Line 1" .. "Line 30") to test line-based coalescing.
    // Row N of the buffer initially holds "Line N+1".
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the coalesced group above, then the far-below insertion, joined by
    // `render_events` with a `---` separator line.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 567
 568fn render_events(events: &[StoredEvent]) -> String {
 569    events
 570        .iter()
 571        .map(|e| {
 572            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 573            diff.as_str()
 574        })
 575        .collect::<Vec<_>>()
 576        .join("\n---\n")
 577}
 578
 579fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 580    events
 581        .iter()
 582        .map(|e| {
 583            let zeta_prompt::Event::BufferChange {
 584                diff, predicted, ..
 585            } = e.event.as_ref();
 586            let prefix = if *predicted { "predicted" } else { "manual" };
 587            format!("{}\n{}", prefix, diff)
 588        })
 589        .collect()
 590}
 591
/// Verifies how the `predicted` flag on edit-history events interacts with coalescing.
///
/// - Plain buffer edits are labelled `manual`.
/// - Edits followed by `report_changes_for_buffer(.., true, ..)` are labelled `predicted`.
/// - Nearby edits with the same label merge into one event.
/// - Nearby edits with different labels stay in separate events.
/// - When an edit far away (more than 8 lines) starts a new event, the previous group merges
///   into a single event, labelled `predicted` only if every edit in it was predicted.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true. Passing `true` to
    // `report_changes_for_buffer` marks the edit as predicted. It is not merged into the
    // preceding manual event, even though it touches the very next line.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5: A manual edit that follows a predicted edit is not merged
    // with the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
 850
 851#[gpui::test]
 852async fn test_empty_prediction(cx: &mut TestAppContext) {
 853    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 854    let fs = FakeFs::new(cx.executor());
 855    fs.insert_tree(
 856        "/root",
 857        json!({
 858            "foo.md":  "Hello!\nHow\nBye\n"
 859        }),
 860    )
 861    .await;
 862    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 863
 864    let buffer = project
 865        .update(cx, |project, cx| {
 866            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 867            project.open_buffer(path, cx)
 868        })
 869        .await
 870        .unwrap();
 871    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 872    let position = snapshot.anchor_before(language::Point::new(1, 3));
 873
 874    ep_store.update(cx, |ep_store, cx| {
 875        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 876    });
 877
 878    let (request, respond_tx) = requests.predict.next().await.unwrap();
 879    let response = model_response(&request, "");
 880    let id = response.request_id.clone();
 881    respond_tx.send(response).unwrap();
 882
 883    cx.run_until_parked();
 884
 885    ep_store.update(cx, |ep_store, cx| {
 886        assert!(
 887            ep_store
 888                .prediction_at(&buffer, None, &project, cx)
 889                .is_none()
 890        );
 891    });
 892
 893    // prediction is reported as rejected
 894    let (reject_request, _) = requests.reject.next().await.unwrap();
 895
 896    assert_eq!(
 897        &reject_request.rejections,
 898        &[EditPredictionRejection {
 899            request_id: id,
 900            reason: EditPredictionRejectReason::Empty,
 901            was_shown: false
 902        }]
 903    );
 904}
 905
 906#[gpui::test]
 907async fn test_interpolated_empty(cx: &mut TestAppContext) {
 908    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 909    let fs = FakeFs::new(cx.executor());
 910    fs.insert_tree(
 911        "/root",
 912        json!({
 913            "foo.md":  "Hello!\nHow\nBye\n"
 914        }),
 915    )
 916    .await;
 917    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 918
 919    let buffer = project
 920        .update(cx, |project, cx| {
 921            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 922            project.open_buffer(path, cx)
 923        })
 924        .await
 925        .unwrap();
 926    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 927    let position = snapshot.anchor_before(language::Point::new(1, 3));
 928
 929    ep_store.update(cx, |ep_store, cx| {
 930        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 931    });
 932
 933    let (request, respond_tx) = requests.predict.next().await.unwrap();
 934
 935    buffer.update(cx, |buffer, cx| {
 936        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 937    });
 938
 939    let response = model_response(&request, SIMPLE_DIFF);
 940    let id = response.request_id.clone();
 941    respond_tx.send(response).unwrap();
 942
 943    cx.run_until_parked();
 944
 945    ep_store.update(cx, |ep_store, cx| {
 946        assert!(
 947            ep_store
 948                .prediction_at(&buffer, None, &project, cx)
 949                .is_none()
 950        );
 951    });
 952
 953    // prediction is reported as rejected
 954    let (reject_request, _) = requests.reject.next().await.unwrap();
 955
 956    assert_eq!(
 957        &reject_request.rejections,
 958        &[EditPredictionRejection {
 959            request_id: id,
 960            reason: EditPredictionRejectReason::InterpolatedEmpty,
 961            was_shown: false
 962        }]
 963    );
 964}
 965
/// Unified diff used as canned model output by several tests below.
///
/// It rewrites the line `How` to `How are you?` in `root/foo.md`, with `Hello!`
/// and `Bye` as unchanged context lines. The leading space on context lines is
/// significant to the diff format, so keep this literal's whitespace exact.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 975
 976#[gpui::test]
 977async fn test_replace_current(cx: &mut TestAppContext) {
 978    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 979    let fs = FakeFs::new(cx.executor());
 980    fs.insert_tree(
 981        "/root",
 982        json!({
 983            "foo.md":  "Hello!\nHow\nBye\n"
 984        }),
 985    )
 986    .await;
 987    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 988
 989    let buffer = project
 990        .update(cx, |project, cx| {
 991            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 992            project.open_buffer(path, cx)
 993        })
 994        .await
 995        .unwrap();
 996    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 997    let position = snapshot.anchor_before(language::Point::new(1, 3));
 998
 999    ep_store.update(cx, |ep_store, cx| {
1000        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1001    });
1002
1003    let (request, respond_tx) = requests.predict.next().await.unwrap();
1004    let first_response = model_response(&request, SIMPLE_DIFF);
1005    let first_id = first_response.request_id.clone();
1006    respond_tx.send(first_response).unwrap();
1007
1008    cx.run_until_parked();
1009
1010    ep_store.update(cx, |ep_store, cx| {
1011        assert_eq!(
1012            ep_store
1013                .prediction_at(&buffer, None, &project, cx)
1014                .unwrap()
1015                .id
1016                .0,
1017            first_id
1018        );
1019    });
1020
1021    // a second request is triggered
1022    ep_store.update(cx, |ep_store, cx| {
1023        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1024    });
1025
1026    let (request, respond_tx) = requests.predict.next().await.unwrap();
1027    let second_response = model_response(&request, SIMPLE_DIFF);
1028    let second_id = second_response.request_id.clone();
1029    respond_tx.send(second_response).unwrap();
1030
1031    cx.run_until_parked();
1032
1033    ep_store.update(cx, |ep_store, cx| {
1034        // second replaces first
1035        assert_eq!(
1036            ep_store
1037                .prediction_at(&buffer, None, &project, cx)
1038                .unwrap()
1039                .id
1040                .0,
1041            second_id
1042        );
1043    });
1044
1045    // first is reported as replaced
1046    let (reject_request, _) = requests.reject.next().await.unwrap();
1047
1048    assert_eq!(
1049        &reject_request.rejections,
1050        &[EditPredictionRejection {
1051            request_id: first_id,
1052            reason: EditPredictionRejectReason::Replaced,
1053            was_shown: false
1054        }]
1055    );
1056}
1057
1058#[gpui::test]
1059async fn test_current_preferred(cx: &mut TestAppContext) {
1060    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1061    let fs = FakeFs::new(cx.executor());
1062    fs.insert_tree(
1063        "/root",
1064        json!({
1065            "foo.md":  "Hello!\nHow\nBye\n"
1066        }),
1067    )
1068    .await;
1069    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1070
1071    let buffer = project
1072        .update(cx, |project, cx| {
1073            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1074            project.open_buffer(path, cx)
1075        })
1076        .await
1077        .unwrap();
1078    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1079    let position = snapshot.anchor_before(language::Point::new(1, 3));
1080
1081    ep_store.update(cx, |ep_store, cx| {
1082        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1083    });
1084
1085    let (request, respond_tx) = requests.predict.next().await.unwrap();
1086    let first_response = model_response(&request, SIMPLE_DIFF);
1087    let first_id = first_response.request_id.clone();
1088    respond_tx.send(first_response).unwrap();
1089
1090    cx.run_until_parked();
1091
1092    ep_store.update(cx, |ep_store, cx| {
1093        assert_eq!(
1094            ep_store
1095                .prediction_at(&buffer, None, &project, cx)
1096                .unwrap()
1097                .id
1098                .0,
1099            first_id
1100        );
1101    });
1102
1103    // a second request is triggered
1104    ep_store.update(cx, |ep_store, cx| {
1105        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1106    });
1107
1108    let (request, respond_tx) = requests.predict.next().await.unwrap();
1109    // worse than current prediction
1110    let second_response = model_response(
1111        &request,
1112        indoc! { r"
1113            --- a/root/foo.md
1114            +++ b/root/foo.md
1115            @@ ... @@
1116             Hello!
1117            -How
1118            +How are
1119             Bye
1120        "},
1121    );
1122    let second_id = second_response.request_id.clone();
1123    respond_tx.send(second_response).unwrap();
1124
1125    cx.run_until_parked();
1126
1127    ep_store.update(cx, |ep_store, cx| {
1128        // first is preferred over second
1129        assert_eq!(
1130            ep_store
1131                .prediction_at(&buffer, None, &project, cx)
1132                .unwrap()
1133                .id
1134                .0,
1135            first_id
1136        );
1137    });
1138
1139    // second is reported as rejected
1140    let (reject_request, _) = requests.reject.next().await.unwrap();
1141
1142    assert_eq!(
1143        &reject_request.rejections,
1144        &[EditPredictionRejection {
1145            request_id: second_id,
1146            reason: EditPredictionRejectReason::CurrentPreferred,
1147            was_shown: false
1148        }]
1149    );
1150}
1151
1152#[gpui::test]
1153async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1154    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1155    let fs = FakeFs::new(cx.executor());
1156    fs.insert_tree(
1157        "/root",
1158        json!({
1159            "foo.md":  "Hello!\nHow\nBye\n"
1160        }),
1161    )
1162    .await;
1163    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1164
1165    let buffer = project
1166        .update(cx, |project, cx| {
1167            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1168            project.open_buffer(path, cx)
1169        })
1170        .await
1171        .unwrap();
1172    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1173    let position = snapshot.anchor_before(language::Point::new(1, 3));
1174
1175    // start two refresh tasks
1176    ep_store.update(cx, |ep_store, cx| {
1177        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1178    });
1179
1180    let (request1, respond_first) = requests.predict.next().await.unwrap();
1181
1182    ep_store.update(cx, |ep_store, cx| {
1183        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1184    });
1185
1186    let (request, respond_second) = requests.predict.next().await.unwrap();
1187
1188    // wait for throttle
1189    cx.run_until_parked();
1190
1191    // second responds first
1192    let second_response = model_response(&request, SIMPLE_DIFF);
1193    let second_id = second_response.request_id.clone();
1194    respond_second.send(second_response).unwrap();
1195
1196    cx.run_until_parked();
1197
1198    ep_store.update(cx, |ep_store, cx| {
1199        // current prediction is second
1200        assert_eq!(
1201            ep_store
1202                .prediction_at(&buffer, None, &project, cx)
1203                .unwrap()
1204                .id
1205                .0,
1206            second_id
1207        );
1208    });
1209
1210    let first_response = model_response(&request1, SIMPLE_DIFF);
1211    let first_id = first_response.request_id.clone();
1212    respond_first.send(first_response).unwrap();
1213
1214    cx.run_until_parked();
1215
1216    ep_store.update(cx, |ep_store, cx| {
1217        // current prediction is still second, since first was cancelled
1218        assert_eq!(
1219            ep_store
1220                .prediction_at(&buffer, None, &project, cx)
1221                .unwrap()
1222                .id
1223                .0,
1224            second_id
1225        );
1226    });
1227
1228    // first is reported as rejected
1229    let (reject_request, _) = requests.reject.next().await.unwrap();
1230
1231    cx.run_until_parked();
1232
1233    assert_eq!(
1234        &reject_request.rejections,
1235        &[EditPredictionRejection {
1236            request_id: first_id,
1237            reason: EditPredictionRejectReason::Canceled,
1238            was_shown: false
1239        }]
1240    );
1241}
1242
1243#[gpui::test]
1244async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
1245    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1246    let fs = FakeFs::new(cx.executor());
1247    fs.insert_tree(
1248        "/root",
1249        json!({
1250            "foo.md":  "Hello!\nHow\nBye\n"
1251        }),
1252    )
1253    .await;
1254    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1255
1256    let buffer = project
1257        .update(cx, |project, cx| {
1258            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1259            project.open_buffer(path, cx)
1260        })
1261        .await
1262        .unwrap();
1263    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1264    let position = snapshot.anchor_before(language::Point::new(1, 3));
1265
1266    // start two refresh tasks
1267    ep_store.update(cx, |ep_store, cx| {
1268        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1269    });
1270
1271    let (request1, respond_first) = requests.predict.next().await.unwrap();
1272
1273    ep_store.update(cx, |ep_store, cx| {
1274        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1275    });
1276
1277    let (request2, respond_second) = requests.predict.next().await.unwrap();
1278
1279    // wait for throttle, so requests are sent
1280    cx.run_until_parked();
1281
1282    ep_store.update(cx, |ep_store, cx| {
1283        // start a third request
1284        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1285
1286        // 2 are pending, so 2nd is cancelled
1287        assert_eq!(
1288            ep_store
1289                .get_or_init_project(&project, cx)
1290                .cancelled_predictions
1291                .iter()
1292                .copied()
1293                .collect::<Vec<_>>(),
1294            [1]
1295        );
1296    });
1297
1298    // wait for throttle
1299    cx.run_until_parked();
1300
1301    let (request3, respond_third) = requests.predict.next().await.unwrap();
1302
1303    let first_response = model_response(&request1, SIMPLE_DIFF);
1304    let first_id = first_response.request_id.clone();
1305    respond_first.send(first_response).unwrap();
1306
1307    cx.run_until_parked();
1308
1309    ep_store.update(cx, |ep_store, cx| {
1310        // current prediction is first
1311        assert_eq!(
1312            ep_store
1313                .prediction_at(&buffer, None, &project, cx)
1314                .unwrap()
1315                .id
1316                .0,
1317            first_id
1318        );
1319    });
1320
1321    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
1322    let cancelled_id = cancelled_response.request_id.clone();
1323    respond_second.send(cancelled_response).unwrap();
1324
1325    cx.run_until_parked();
1326
1327    ep_store.update(cx, |ep_store, cx| {
1328        // current prediction is still first, since second was cancelled
1329        assert_eq!(
1330            ep_store
1331                .prediction_at(&buffer, None, &project, cx)
1332                .unwrap()
1333                .id
1334                .0,
1335            first_id
1336        );
1337    });
1338
1339    let third_response = model_response(&request3, SIMPLE_DIFF);
1340    let third_response_id = third_response.request_id.clone();
1341    respond_third.send(third_response).unwrap();
1342
1343    cx.run_until_parked();
1344
1345    ep_store.update(cx, |ep_store, cx| {
1346        // third completes and replaces first
1347        assert_eq!(
1348            ep_store
1349                .prediction_at(&buffer, None, &project, cx)
1350                .unwrap()
1351                .id
1352                .0,
1353            third_response_id
1354        );
1355    });
1356
1357    // second is reported as rejected
1358    let (reject_request, _) = requests.reject.next().await.unwrap();
1359
1360    cx.run_until_parked();
1361
1362    assert_eq!(
1363        &reject_request.rejections,
1364        &[
1365            EditPredictionRejection {
1366                request_id: cancelled_id,
1367                reason: EditPredictionRejectReason::Canceled,
1368                was_shown: false
1369            },
1370            EditPredictionRejection {
1371                request_id: first_id,
1372                reason: EditPredictionRejectReason::Replaced,
1373                was_shown: false
1374            }
1375        ]
1376    );
1377}
1378
1379#[gpui::test]
1380async fn test_rejections_flushing(cx: &mut TestAppContext) {
1381    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1382
1383    ep_store.update(cx, |ep_store, cx| {
1384        ep_store.reject_prediction(
1385            EditPredictionId("test-1".into()),
1386            EditPredictionRejectReason::Discarded,
1387            false,
1388            cx,
1389        );
1390        ep_store.reject_prediction(
1391            EditPredictionId("test-2".into()),
1392            EditPredictionRejectReason::Canceled,
1393            true,
1394            cx,
1395        );
1396    });
1397
1398    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1399    cx.run_until_parked();
1400
1401    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1402    respond_tx.send(()).unwrap();
1403
1404    // batched
1405    assert_eq!(reject_request.rejections.len(), 2);
1406    assert_eq!(
1407        reject_request.rejections[0],
1408        EditPredictionRejection {
1409            request_id: "test-1".to_string(),
1410            reason: EditPredictionRejectReason::Discarded,
1411            was_shown: false
1412        }
1413    );
1414    assert_eq!(
1415        reject_request.rejections[1],
1416        EditPredictionRejection {
1417            request_id: "test-2".to_string(),
1418            reason: EditPredictionRejectReason::Canceled,
1419            was_shown: true
1420        }
1421    );
1422
1423    // Reaching batch size limit sends without debounce
1424    ep_store.update(cx, |ep_store, cx| {
1425        for i in 0..70 {
1426            ep_store.reject_prediction(
1427                EditPredictionId(format!("batch-{}", i).into()),
1428                EditPredictionRejectReason::Discarded,
1429                false,
1430                cx,
1431            );
1432        }
1433    });
1434
1435    // First MAX/2 items are sent immediately
1436    cx.run_until_parked();
1437    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1438    respond_tx.send(()).unwrap();
1439
1440    assert_eq!(reject_request.rejections.len(), 50);
1441    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1442    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1443
1444    // Remaining items are debounced with the next batch
1445    cx.executor().advance_clock(Duration::from_secs(15));
1446    cx.run_until_parked();
1447
1448    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1449    respond_tx.send(()).unwrap();
1450
1451    assert_eq!(reject_request.rejections.len(), 20);
1452    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1453    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1454
1455    // Request failure
1456    ep_store.update(cx, |ep_store, cx| {
1457        ep_store.reject_prediction(
1458            EditPredictionId("retry-1".into()),
1459            EditPredictionRejectReason::Discarded,
1460            false,
1461            cx,
1462        );
1463    });
1464
1465    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1466    cx.run_until_parked();
1467
1468    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1469    assert_eq!(reject_request.rejections.len(), 1);
1470    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1471    // Simulate failure
1472    drop(_respond_tx);
1473
1474    // Add another rejection
1475    ep_store.update(cx, |ep_store, cx| {
1476        ep_store.reject_prediction(
1477            EditPredictionId("retry-2".into()),
1478            EditPredictionRejectReason::Discarded,
1479            false,
1480            cx,
1481        );
1482    });
1483
1484    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1485    cx.run_until_parked();
1486
1487    // Retry should include both the failed item and the new one
1488    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1489    respond_tx.send(()).unwrap();
1490
1491    assert_eq!(reject_request.rejections.len(), 2);
1492    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1493    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1494}
1495
1496// Skipped until we start including diagnostics in prompt
1497// #[gpui::test]
1498// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1499//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1500//     let fs = FakeFs::new(cx.executor());
1501//     fs.insert_tree(
1502//         "/root",
1503//         json!({
1504//             "foo.md": "Hello!\nBye"
1505//         }),
1506//     )
1507//     .await;
1508//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1509
1510//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1511//     let diagnostic = lsp::Diagnostic {
1512//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1513//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1514//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1515//         ..Default::default()
1516//     };
1517
1518//     project.update(cx, |project, cx| {
1519//         project.lsp_store().update(cx, |lsp_store, cx| {
1520//             // Create some diagnostics
1521//             lsp_store
1522//                 .update_diagnostics(
1523//                     LanguageServerId(0),
1524//                     lsp::PublishDiagnosticsParams {
1525//                         uri: path_to_buffer_uri.clone(),
1526//                         diagnostics: vec![diagnostic],
1527//                         version: None,
1528//                     },
1529//                     None,
1530//                     language::DiagnosticSourceKind::Pushed,
1531//                     &[],
1532//                     cx,
1533//                 )
1534//                 .unwrap();
1535//         });
1536//     });
1537
1538//     let buffer = project
1539//         .update(cx, |project, cx| {
1540//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1541//             project.open_buffer(path, cx)
1542//         })
1543//         .await
1544//         .unwrap();
1545
1546//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1547//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1548
1549//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1550//         ep_store.request_prediction(&project, &buffer, position, cx)
1551//     });
1552
1553//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1554
1555//     assert_eq!(request.diagnostic_groups.len(), 1);
1556//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1557//         .unwrap();
1558//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1559//     assert_eq!(
1560//         value,
1561//         json!({
1562//             "entries": [{
1563//                 "range": {
1564//                     "start": 8,
1565//                     "end": 10
1566//                 },
1567//                 "diagnostic": {
1568//                     "source": null,
1569//                     "code": null,
1570//                     "code_description": null,
1571//                     "severity": 1,
1572//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1573//                     "markdown": null,
1574//                     "group_id": 0,
1575//                     "is_primary": true,
1576//                     "is_disk_based": false,
1577//                     "is_unnecessary": false,
1578//                     "source_kind": "Pushed",
1579//                     "data": null,
1580//                     "underline": true
1581//                 }
1582//             }],
1583//             "primary_ix": 0
1584//         })
1585//     );
1586// }
1587
1588// Generate a model response that would apply the given diff to the active file.
1589fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1590    let excerpt =
1591        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1592    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1593
1594    PredictEditsV3Response {
1595        request_id: Uuid::new_v4().to_string(),
1596        output: new_excerpt,
1597    }
1598}
1599
1600fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1601    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
1602}
1603
1604struct RequestChannels {
1605    predict: mpsc::UnboundedReceiver<(
1606        PredictEditsV3Request,
1607        oneshot::Sender<PredictEditsV3Response>,
1608    )>,
1609    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1610}
1611
1612fn init_test_with_fake_client(
1613    cx: &mut TestAppContext,
1614) -> (Entity<EditPredictionStore>, RequestChannels) {
1615    cx.update(move |cx| {
1616        let settings_store = SettingsStore::test(cx);
1617        cx.set_global(settings_store);
1618        zlog::init_test();
1619
1620        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1621        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1622
1623        let http_client = FakeHttpClient::create({
1624            move |req| {
1625                let uri = req.uri().path().to_string();
1626                let mut body = req.into_body();
1627                let predict_req_tx = predict_req_tx.clone();
1628                let reject_req_tx = reject_req_tx.clone();
1629                async move {
1630                    let resp = match uri.as_str() {
1631                        "/client/llm_tokens" => serde_json::to_string(&json!({
1632                            "token": "test"
1633                        }))
1634                        .unwrap(),
1635                        "/predict_edits/v3" => {
1636                            let mut buf = Vec::new();
1637                            body.read_to_end(&mut buf).await.ok();
1638                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
1639                            let req = serde_json::from_slice(&decompressed).unwrap();
1640
1641                            let (res_tx, res_rx) = oneshot::channel();
1642                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1643                            serde_json::to_string(&res_rx.await?).unwrap()
1644                        }
1645                        "/predict_edits/reject" => {
1646                            let mut buf = Vec::new();
1647                            body.read_to_end(&mut buf).await.ok();
1648                            let req = serde_json::from_slice(&buf).unwrap();
1649
1650                            let (res_tx, res_rx) = oneshot::channel();
1651                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1652                            serde_json::to_string(&res_rx.await?).unwrap()
1653                        }
1654                        _ => {
1655                            panic!("Unexpected path: {}", uri)
1656                        }
1657                    };
1658
1659                    Ok(Response::builder().body(resp.into()).unwrap())
1660                }
1661            }
1662        });
1663
1664        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1665        client.cloud_client().set_credentials(1, "test".into());
1666
1667        language_model::init(client.clone(), cx);
1668
1669        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1670        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1671
1672        (
1673            ep_store,
1674            RequestChannels {
1675                predict: predict_req_rx,
1676                reject: reject_req_rx,
1677            },
1678        )
1679    })
1680}
1681
/// Text of the `0bsd.txt` license example fixture, embedded at compile time.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1683
// Checks how a prediction's edits are re-based ("interpolated") onto later
// buffer snapshots as the user edits. Text the user has already typed is
// dropped from the remaining edits, and remaining edits shift with user
// insertions/deletions. Once the user's edits diverge from the prediction,
// interpolation yields `None`.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Predicted edits against the original text: replace 2..5 with "REM" and
    // delete 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Only `edits`, `buffer` and `snapshot` matter for interpolation here; the
    // remaining fields are placeholders.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: edits come back exactly as predicted.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // The replaced range is deleted by the user: "REM" becomes a pure
        // insertion and the later deletion shifts left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // The user types the first predicted character: only "EM" remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // ...then the second: only "M" remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // ...then the third: the insertion is complete and only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the typed "M" brings the pending "M" insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // The user performs the predicted deletion themselves.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that diverges from the prediction invalidates it entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1799
1800#[gpui::test]
1801async fn test_clean_up_diff(cx: &mut TestAppContext) {
1802    init_test(cx);
1803
1804    assert_eq!(
1805        apply_edit_prediction(
1806            indoc! {"
1807                    fn main() {
1808                        let word_1 = \"lorem\";
1809                        let range = word.len()..word.len();
1810                    }
1811                "},
1812            indoc! {"
1813                    <|editable_region_start|>
1814                    fn main() {
1815                        let word_1 = \"lorem\";
1816                        let range = word_1.len()..word_1.len();
1817                    }
1818
1819                    <|editable_region_end|>
1820                "},
1821            cx,
1822        )
1823        .await,
1824        indoc! {"
1825                fn main() {
1826                    let word_1 = \"lorem\";
1827                    let range = word_1.len()..word_1.len();
1828                }
1829            "},
1830    );
1831
1832    assert_eq!(
1833        apply_edit_prediction(
1834            indoc! {"
1835                    fn main() {
1836                        let story = \"the quick\"
1837                    }
1838                "},
1839            indoc! {"
1840                    <|editable_region_start|>
1841                    fn main() {
1842                        let story = \"the quick brown fox jumps over the lazy dog\";
1843                    }
1844
1845                    <|editable_region_end|>
1846                "},
1847            cx,
1848        )
1849        .await,
1850        indoc! {"
1851                fn main() {
1852                    let story = \"the quick brown fox jumps over the lazy dog\";
1853                }
1854            "},
1855    );
1856}
1857
1858#[gpui::test]
1859async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1860    init_test(cx);
1861
1862    let buffer_content = "lorem\n";
1863    let completion_response = indoc! {"
1864            ```animals.js
1865            <|start_of_file|>
1866            <|editable_region_start|>
1867            lorem
1868            ipsum
1869            <|editable_region_end|>
1870            ```"};
1871
1872    assert_eq!(
1873        apply_edit_prediction(buffer_content, completion_response, cx).await,
1874        "lorem\nipsum"
1875    );
1876}
1877
1878#[gpui::test]
1879async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1880    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1881    // When the buffer ends without a trailing newline, but the model returns output
1882    // with a trailing newline, zeta2 should normalize both sides before diffing
1883    // so no spurious newline is inserted.
1884    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1885    let fs = FakeFs::new(cx.executor());
1886
1887    // Single line buffer with no trailing newline
1888    fs.insert_tree(
1889        "/root",
1890        json!({
1891            "foo.txt": "hello"
1892        }),
1893    )
1894    .await;
1895    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1896
1897    let buffer = project
1898        .update(cx, |project, cx| {
1899            let path = project
1900                .find_project_path(path!("root/foo.txt"), cx)
1901                .unwrap();
1902            project.open_buffer(path, cx)
1903        })
1904        .await
1905        .unwrap();
1906
1907    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1908    let position = snapshot.anchor_before(language::Point::new(0, 5));
1909
1910    ep_store.update(cx, |ep_store, cx| {
1911        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1912    });
1913
1914    let (_request, respond_tx) = requests.predict.next().await.unwrap();
1915
1916    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1917    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1918    let response = PredictEditsV3Response {
1919        request_id: Uuid::new_v4().to_string(),
1920        output: "hello world\n".to_string(),
1921    };
1922    respond_tx.send(response).unwrap();
1923
1924    cx.run_until_parked();
1925
1926    // The prediction should insert " world" without adding a newline
1927    ep_store.update(cx, |ep_store, cx| {
1928        let prediction = ep_store
1929            .prediction_at(&buffer, None, &project, cx)
1930            .expect("should have prediction");
1931        let edits: Vec<_> = prediction
1932            .edits
1933            .iter()
1934            .map(|(range, text)| {
1935                let snapshot = buffer.read(cx).snapshot();
1936                (range.to_offset(&snapshot), text.clone())
1937            })
1938            .collect();
1939        assert_eq!(edits, vec![(5..5, " world".into())]);
1940    });
1941}
1942
1943#[gpui::test]
1944async fn test_can_collect_data(cx: &mut TestAppContext) {
1945    init_test(cx);
1946
1947    let fs = project::FakeFs::new(cx.executor());
1948    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1949        .await;
1950
1951    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1952    let buffer = project
1953        .update(cx, |project, cx| {
1954            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1955        })
1956        .await
1957        .unwrap();
1958
1959    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1960    ep_store.update(cx, |ep_store, _cx| {
1961        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1962    });
1963
1964    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1965    assert_eq!(
1966        captured_request.lock().clone().unwrap().can_collect_data,
1967        true
1968    );
1969
1970    ep_store.update(cx, |ep_store, _cx| {
1971        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1972    });
1973
1974    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1975    assert_eq!(
1976        captured_request.lock().clone().unwrap().can_collect_data,
1977        false
1978    );
1979}
1980
1981#[gpui::test]
1982async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1983    init_test(cx);
1984
1985    let fs = project::FakeFs::new(cx.executor());
1986    let project = Project::test(fs.clone(), [], cx).await;
1987
1988    let buffer = cx.new(|_cx| {
1989        Buffer::remote(
1990            language::BufferId::new(1).unwrap(),
1991            ReplicaId::new(1),
1992            language::Capability::ReadWrite,
1993            "fn main() {\n    println!(\"Hello\");\n}",
1994        )
1995    });
1996
1997    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1998    ep_store.update(cx, |ep_store, _cx| {
1999        ep_store.data_collection_choice = DataCollectionChoice::Enabled
2000    });
2001
2002    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2003    assert_eq!(
2004        captured_request.lock().clone().unwrap().can_collect_data,
2005        false
2006    );
2007}
2008
2009#[gpui::test]
2010async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
2011    init_test(cx);
2012
2013    let fs = project::FakeFs::new(cx.executor());
2014    fs.insert_tree(
2015        path!("/project"),
2016        json!({
2017            "LICENSE": BSD_0_TXT,
2018            ".env": "SECRET_KEY=secret"
2019        }),
2020    )
2021    .await;
2022
2023    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2024    let buffer = project
2025        .update(cx, |project, cx| {
2026            project.open_local_buffer("/project/.env", cx)
2027        })
2028        .await
2029        .unwrap();
2030
2031    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2032    ep_store.update(cx, |ep_store, _cx| {
2033        ep_store.data_collection_choice = DataCollectionChoice::Enabled
2034    });
2035
2036    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2037    assert_eq!(
2038        captured_request.lock().clone().unwrap().can_collect_data,
2039        false
2040    );
2041}
2042
2043#[gpui::test]
2044async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
2045    init_test(cx);
2046
2047    let fs = project::FakeFs::new(cx.executor());
2048    let project = Project::test(fs.clone(), [], cx).await;
2049    let buffer = cx.new(|cx| Buffer::local("", cx));
2050
2051    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2052    ep_store.update(cx, |ep_store, _cx| {
2053        ep_store.data_collection_choice = DataCollectionChoice::Enabled
2054    });
2055
2056    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2057    assert_eq!(
2058        captured_request.lock().clone().unwrap().can_collect_data,
2059        false
2060    );
2061}
2062
2063#[gpui::test]
2064async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
2065    init_test(cx);
2066
2067    let fs = project::FakeFs::new(cx.executor());
2068    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
2069        .await;
2070
2071    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2072    let buffer = project
2073        .update(cx, |project, cx| {
2074            project.open_local_buffer("/project/main.rs", cx)
2075        })
2076        .await
2077        .unwrap();
2078
2079    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2080    ep_store.update(cx, |ep_store, _cx| {
2081        ep_store.data_collection_choice = DataCollectionChoice::Enabled
2082    });
2083
2084    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2085    assert_eq!(
2086        captured_request.lock().clone().unwrap().can_collect_data,
2087        false
2088    );
2089}
2090
2091#[gpui::test]
2092async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
2093    init_test(cx);
2094
2095    let fs = project::FakeFs::new(cx.executor());
2096    fs.insert_tree(
2097        path!("/open_source_worktree"),
2098        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
2099    )
2100    .await;
2101    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
2102        .await;
2103
2104    let project = Project::test(
2105        fs.clone(),
2106        [
2107            path!("/open_source_worktree").as_ref(),
2108            path!("/closed_source_worktree").as_ref(),
2109        ],
2110        cx,
2111    )
2112    .await;
2113    let buffer = project
2114        .update(cx, |project, cx| {
2115            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
2116        })
2117        .await
2118        .unwrap();
2119
2120    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2121    ep_store.update(cx, |ep_store, _cx| {
2122        ep_store.data_collection_choice = DataCollectionChoice::Enabled
2123    });
2124
2125    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2126    assert_eq!(
2127        captured_request.lock().clone().unwrap().can_collect_data,
2128        true
2129    );
2130
2131    let closed_source_file = project
2132        .update(cx, |project, cx| {
2133            let worktree2 = project
2134                .worktree_for_root_name("closed_source_worktree", cx)
2135                .unwrap();
2136            worktree2.update(cx, |worktree2, cx| {
2137                worktree2.load_file(rel_path("main.rs"), cx)
2138            })
2139        })
2140        .await
2141        .unwrap()
2142        .file;
2143
2144    buffer.update(cx, |buffer, cx| {
2145        buffer.file_updated(closed_source_file, cx);
2146    });
2147
2148    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2149    assert_eq!(
2150        captured_request.lock().clone().unwrap().can_collect_data,
2151        false
2152    );
2153}
2154
2155#[gpui::test]
2156async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
2157    init_test(cx);
2158
2159    let fs = project::FakeFs::new(cx.executor());
2160    fs.insert_tree(
2161        path!("/worktree1"),
2162        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
2163    )
2164    .await;
2165    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
2166        .await;
2167
2168    let project = Project::test(
2169        fs.clone(),
2170        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
2171        cx,
2172    )
2173    .await;
2174    let buffer = project
2175        .update(cx, |project, cx| {
2176            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
2177        })
2178        .await
2179        .unwrap();
2180    let private_buffer = project
2181        .update(cx, |project, cx| {
2182            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
2183        })
2184        .await
2185        .unwrap();
2186
2187    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
2188    ep_store.update(cx, |ep_store, _cx| {
2189        ep_store.data_collection_choice = DataCollectionChoice::Enabled
2190    });
2191
2192    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2193    assert_eq!(
2194        captured_request.lock().clone().unwrap().can_collect_data,
2195        true
2196    );
2197
2198    // this has a side effect of registering the buffer to watch for edits
2199    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
2200    assert_eq!(
2201        captured_request.lock().clone().unwrap().can_collect_data,
2202        false
2203    );
2204
2205    private_buffer.update(cx, |private_buffer, cx| {
2206        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
2207    });
2208
2209    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2210    assert_eq!(
2211        captured_request.lock().clone().unwrap().can_collect_data,
2212        false
2213    );
2214
2215    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
2216    // included
2217    buffer.update(cx, |buffer, cx| {
2218        buffer.edit(
2219            [(
2220                0..0,
2221                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
2222            )],
2223            None,
2224            cx,
2225        );
2226    });
2227
2228    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2229    assert_eq!(
2230        captured_request.lock().clone().unwrap().can_collect_data,
2231        true
2232    );
2233}
2234
2235fn init_test(cx: &mut TestAppContext) {
2236    cx.update(|cx| {
2237        let settings_store = SettingsStore::test(cx);
2238        cx.set_global(settings_store);
2239    });
2240}
2241
2242async fn apply_edit_prediction(
2243    buffer_content: &str,
2244    completion_response: &str,
2245    cx: &mut TestAppContext,
2246) -> String {
2247    let fs = project::FakeFs::new(cx.executor());
2248    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2249    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2250    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2251    *response.lock() = completion_response.to_string();
2252    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2253    buffer.update(cx, |buffer, cx| {
2254        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2255    });
2256    buffer.read_with(cx, |buffer, _| buffer.text())
2257}
2258
2259async fn run_edit_prediction(
2260    buffer: &Entity<Buffer>,
2261    project: &Entity<Project>,
2262    ep_store: &Entity<EditPredictionStore>,
2263    cx: &mut TestAppContext,
2264) -> EditPrediction {
2265    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2266    ep_store.update(cx, |ep_store, cx| {
2267        ep_store.register_buffer(buffer, &project, cx)
2268    });
2269    cx.background_executor.run_until_parked();
2270    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2271        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2272    });
2273    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2274}
2275
2276async fn make_test_ep_store(
2277    project: &Entity<Project>,
2278    cx: &mut TestAppContext,
2279) -> (
2280    Entity<EditPredictionStore>,
2281    Arc<Mutex<Option<PredictEditsBody>>>,
2282    Arc<Mutex<String>>,
2283) {
2284    let default_response = indoc! {"
2285            ```main.rs
2286            <|start_of_file|>
2287            <|editable_region_start|>
2288            hello world
2289            <|editable_region_end|>
2290            ```"
2291    };
2292    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2293    let completion_response: Arc<Mutex<String>> =
2294        Arc::new(Mutex::new(default_response.to_string()));
2295    let http_client = FakeHttpClient::create({
2296        let captured_request = captured_request.clone();
2297        let completion_response = completion_response.clone();
2298        let mut next_request_id = 0;
2299        move |req| {
2300            let captured_request = captured_request.clone();
2301            let completion_response = completion_response.clone();
2302            async move {
2303                match (req.method(), req.uri().path()) {
2304                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2305                        .status(200)
2306                        .body(
2307                            serde_json::to_string(&CreateLlmTokenResponse {
2308                                token: LlmToken("the-llm-token".to_string()),
2309                            })
2310                            .unwrap()
2311                            .into(),
2312                        )
2313                        .unwrap()),
2314                    (&Method::POST, "/predict_edits/v2") => {
2315                        let mut request_body = String::new();
2316                        req.into_body().read_to_string(&mut request_body).await?;
2317                        *captured_request.lock() =
2318                            Some(serde_json::from_str(&request_body).unwrap());
2319                        next_request_id += 1;
2320                        Ok(http_client::Response::builder()
2321                            .status(200)
2322                            .body(
2323                                serde_json::to_string(&PredictEditsResponse {
2324                                    request_id: format!("request-{next_request_id}"),
2325                                    output_excerpt: completion_response.lock().clone(),
2326                                })
2327                                .unwrap()
2328                                .into(),
2329                            )
2330                            .unwrap())
2331                    }
2332                    (&Method::POST, "/predict_edits/v3") => {
2333                        next_request_id += 1;
2334                        Ok(http_client::Response::builder()
2335                            .status(200)
2336                            .body(
2337                                serde_json::to_string(&PredictEditsV3Response {
2338                                    request_id: format!("request-{next_request_id}"),
2339                                    output: "hello world".to_string(),
2340                                })
2341                                .unwrap()
2342                                .into(),
2343                            )
2344                            .unwrap())
2345                    }
2346                    _ => Ok(http_client::Response::builder()
2347                        .status(404)
2348                        .body("Not Found".into())
2349                        .unwrap()),
2350                }
2351            }
2352        }
2353    });
2354
2355    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2356    cx.update(|cx| {
2357        RefreshLlmTokenListener::register(client.clone(), cx);
2358    });
2359    let _server = FakeServer::for_client(42, &client, cx).await;
2360
2361    let ep_store = cx.new(|cx| {
2362        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2363        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2364
2365        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2366        for worktree in worktrees {
2367            let worktree_id = worktree.read(cx).id();
2368            ep_store
2369                .get_or_init_project(project, cx)
2370                .license_detection_watchers
2371                .entry(worktree_id)
2372                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2373        }
2374
2375        ep_store
2376    });
2377
2378    (ep_store, captured_request, completion_response)
2379}
2380
2381fn to_completion_edits(
2382    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2383    buffer: &Entity<Buffer>,
2384    cx: &App,
2385) -> Vec<(Range<Anchor>, Arc<str>)> {
2386    let buffer = buffer.read(cx);
2387    iterator
2388        .into_iter()
2389        .map(|(range, text)| {
2390            (
2391                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2392                text,
2393            )
2394        })
2395        .collect()
2396}
2397
2398fn from_completion_edits(
2399    editor_edits: &[(Range<Anchor>, Arc<str>)],
2400    buffer: &Entity<Buffer>,
2401    cx: &App,
2402) -> Vec<(Range<usize>, Arc<str>)> {
2403    let buffer = buffer.read(cx);
2404    editor_edits
2405        .iter()
2406        .map(|(range, text)| {
2407            (
2408                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2409                text.clone(),
2410            )
2411        })
2412        .collect()
2413}
2414
2415#[gpui::test]
2416async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2417    init_test(cx);
2418
2419    let fs = FakeFs::new(cx.executor());
2420    fs.insert_tree(
2421        "/project",
2422        serde_json::json!({
2423            "main.rs": "fn main() {\n    \n}\n"
2424        }),
2425    )
2426    .await;
2427
2428    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2429
2430    let http_client = FakeHttpClient::create(|_req| async move {
2431        Ok(gpui::http_client::Response::builder()
2432            .status(401)
2433            .body("Unauthorized".into())
2434            .unwrap())
2435    });
2436
2437    let client =
2438        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2439    cx.update(|cx| {
2440        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2441    });
2442
2443    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2444
2445    let buffer = project
2446        .update(cx, |project, cx| {
2447            let path = project
2448                .find_project_path(path!("/project/main.rs"), cx)
2449                .unwrap();
2450            project.open_buffer(path, cx)
2451        })
2452        .await
2453        .unwrap();
2454
2455    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2456    ep_store.update(cx, |ep_store, cx| {
2457        ep_store.register_buffer(&buffer, &project, cx)
2458    });
2459    cx.background_executor.run_until_parked();
2460
2461    let completion_task = ep_store.update(cx, |ep_store, cx| {
2462        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2463        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2464    });
2465
2466    let result = completion_task.await;
2467    assert!(
2468        result.is_err(),
2469        "Without authentication and without custom URL, prediction should fail"
2470    );
2471}
2472
2473#[gpui::test]
2474fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2475    let buffer = cx.new(|cx| {
2476        Buffer::local(
2477            indoc! {"
2478                zero
2479                one
2480                two
2481                three
2482                four
2483                five
2484                six
2485                seven
2486                eight
2487                nine
2488                ten
2489                eleven
2490                twelve
2491                thirteen
2492                fourteen
2493                fifteen
2494                sixteen
2495                seventeen
2496                eighteen
2497                nineteen
2498                twenty
2499                twenty-one
2500                twenty-two
2501                twenty-three
2502                twenty-four
2503            "},
2504            cx,
2505        )
2506    });
2507
2508    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2509
2510    buffer.update(cx, |buffer, cx| {
2511        let point = Point::new(12, 0);
2512        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2513        let point = Point::new(8, 0);
2514        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2515    });
2516
2517    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2518
2519    let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2520
2521    assert_eq!(
2522        diff,
2523        indoc! {"
2524            @@ -6,10 +6,12 @@
2525             five
2526             six
2527             seven
2528            +FIRST INSERTION
2529             eight
2530             nine
2531             ten
2532             eleven
2533            +SECOND INSERTION
2534             twelve
2535             thirteen
2536             fourteen
2537            "}
2538    );
2539}
2540
/// Initializes the test logger once, when the test binary is loaded (via `#[ctor::ctor]`),
/// so log output is available to every test in this file.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}