edit_prediction_tests.rs

   1use super::*;
   2use crate::udiff::apply_diff_to_string;
   3use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
   4use clock::FakeSystemClock;
   5use clock::ReplicaId;
   6use cloud_api_types::{
   7    CreateLlmTokenResponse, LlmToken, Organization, OrganizationConfiguration,
   8    OrganizationEditPredictionConfiguration, OrganizationId,
   9};
  10use cloud_llm_client::{
  11    EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
  12    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  13};
  14use db::AppDatabase;
  15use settings::EditPredictionDataCollectionChoice;
  16
  17use futures::{
  18    AsyncReadExt, FutureExt, StreamExt,
  19    channel::{mpsc, oneshot},
  20};
  21use gpui::App;
  22use gpui::{
  23    Entity, TestAppContext,
  24    http_client::{FakeHttpClient, Response},
  25};
  26use indoc::indoc;
  27use language::{
  28    Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
  29    DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
  30};
  31
  32use lsp::LanguageServerId;
  33use parking_lot::Mutex;
  34use pretty_assertions::{assert_eq, assert_matches};
  35use project::{FakeFs, Project};
  36use serde_json::json;
  37use settings::SettingsStore;
  38use std::{ops::Range, path::Path, sync::Arc, time::Duration};
  39use util::{
  40    path,
  41    test::{TextRangeMarker, marked_text_ranges_by},
  42};
  43use uuid::Uuid;
  44use workspace::{AppState, CollaboratorId, MultiWorkspace};
  45use zeta_prompt::ZetaPromptInput;
  46
  47use crate::{
  48    BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
  49    EditPredictionJumpsFeatureFlag, EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
  50};
  51
/// Checks what `prediction_at` reports for the model's latest prediction:
/// a `Local` prediction when the model edits the buffer being asked about, and a `Jump`
/// (pointing at `2.txt`) when the prediction that follows a diagnostic in `2.txt` targets
/// another file. Once `2.txt` is opened, the same prediction is `Local` for that buffer.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at row 1, column 3: the end of "How".
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    // Intercept the outgoing request and answer it with an edit to 1.txt.
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction edits buffer1 itself, so it is reported as local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Reject the current prediction before the diagnostic-driven request below.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publish the diagnostic for 2.txt (not yet open) on behalf of language server 0.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // The next request is answered with an edit to 2.txt.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from buffer1, a prediction that edits 2.txt is a jump to that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from the now-open 2.txt buffer, the same prediction is local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 187
 188#[gpui::test]
 189async fn test_diagnostics_refresh_suppressed_while_following(cx: &mut TestAppContext) {
 190    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 191
 192    cx.update(|cx| {
 193        cx.update_flags(
 194            false,
 195            vec![EditPredictionJumpsFeatureFlag::NAME.to_string()],
 196        );
 197    });
 198
 199    let fs = FakeFs::new(cx.executor());
 200    fs.insert_tree(
 201        "/root",
 202        json!({
 203            "1.txt": "Hello!\nHow\nBye\n",
 204            "2.txt": "Hola!\nComo\nAdios\n"
 205        }),
 206    )
 207    .await;
 208    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 209
 210    let app_state = cx.update(|cx| {
 211        let app_state = AppState::test(cx);
 212        AppState::set_global(app_state.clone(), cx);
 213        app_state
 214    });
 215
 216    let multi_workspace =
 217        cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
 218    let workspace = multi_workspace
 219        .read_with(cx, |multi_workspace, _| multi_workspace.workspace().clone())
 220        .unwrap();
 221    cx.update(|cx| {
 222        AppState::set_global(workspace.read(cx).app_state().clone(), cx);
 223    });
 224    let _ = app_state;
 225
 226    let buffer1 = project
 227        .update(cx, |project, cx| {
 228            let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
 229            project.set_active_path(Some(path.clone()), cx);
 230            project.open_buffer(path, cx)
 231        })
 232        .await
 233        .unwrap();
 234    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
 235    let position = snapshot1.anchor_before(language::Point::new(1, 3));
 236
 237    ep_store.update(cx, |ep_store, cx| {
 238        ep_store.register_project(&project, cx);
 239        ep_store.register_buffer(&buffer1, &project, cx);
 240        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx);
 241    });
 242
 243    let (request, respond_tx) = requests.predict.next().await.unwrap();
 244    respond_tx
 245        .send(model_response(
 246            &request,
 247            indoc! {r"
 248                --- a/root/1.txt
 249                +++ b/root/1.txt
 250                @@ ... @@
 251                 Hello!
 252                -How
 253                +How are you?
 254                 Bye
 255            "},
 256        ))
 257        .unwrap();
 258    cx.run_until_parked();
 259
 260    ep_store.update(cx, |ep_store, cx| {
 261        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
 262    });
 263
 264    let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
 265        multi_workspace.workspace().update(cx, |workspace, cx| {
 266            workspace.start_following(CollaboratorId::Agent, window, cx);
 267        });
 268    });
 269    cx.run_until_parked();
 270
 271    let diagnostic = lsp::Diagnostic {
 272        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 273        severity: Some(lsp::DiagnosticSeverity::ERROR),
 274        message: "Sentence is incomplete".to_string(),
 275        ..Default::default()
 276    };
 277
 278    project.update(cx, |project, cx| {
 279        project.lsp_store().update(cx, |lsp_store, cx| {
 280            lsp_store
 281                .update_diagnostics(
 282                    LanguageServerId(0),
 283                    lsp::PublishDiagnosticsParams {
 284                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 285                        diagnostics: vec![diagnostic.clone()],
 286                        version: None,
 287                    },
 288                    None,
 289                    language::DiagnosticSourceKind::Pushed,
 290                    &[],
 291                    cx,
 292                )
 293                .unwrap();
 294        });
 295    });
 296
 297    cx.run_until_parked();
 298    assert_no_predict_request_ready(&mut requests.predict);
 299
 300    let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
 301        multi_workspace.workspace().update(cx, |workspace, cx| {
 302            workspace.unfollow(CollaboratorId::Agent, window, cx);
 303        });
 304    });
 305    cx.run_until_parked();
 306
 307    project.update(cx, |project, cx| {
 308        project.lsp_store().update(cx, |lsp_store, cx| {
 309            lsp_store
 310                .update_diagnostics(
 311                    LanguageServerId(0),
 312                    lsp::PublishDiagnosticsParams {
 313                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 314                        diagnostics: vec![diagnostic],
 315                        version: None,
 316                    },
 317                    None,
 318                    language::DiagnosticSourceKind::Pushed,
 319                    &[],
 320                    cx,
 321                )
 322                .unwrap();
 323        });
 324    });
 325
 326    let (request, respond_tx) = requests.predict.next().await.unwrap();
 327    respond_tx
 328        .send(model_response(
 329            &request,
 330            indoc! {r#"
 331                --- a/root/2.txt
 332                +++ b/root/2.txt
 333                @@ ... @@
 334                 Hola!
 335                -Como
 336                +Como estas?
 337                 Adios
 338            "#},
 339        ))
 340        .unwrap();
 341    cx.run_until_parked();
 342
 343    ep_store.update(cx, |ep_store, cx| {
 344        let prediction = ep_store
 345            .prediction_at(&buffer1, None, &project, cx)
 346            .unwrap();
 347        assert_matches!(
 348            prediction,
 349            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 350        );
 351    });
 352}
 353
 354#[gpui::test]
 355async fn test_simple_request(cx: &mut TestAppContext) {
 356    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 357    let fs = FakeFs::new(cx.executor());
 358    fs.insert_tree(
 359        "/root",
 360        json!({
 361            "foo.md":  "Hello!\nHow\nBye\n"
 362        }),
 363    )
 364    .await;
 365    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 366
 367    let buffer = project
 368        .update(cx, |project, cx| {
 369            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 370            project.open_buffer(path, cx)
 371        })
 372        .await
 373        .unwrap();
 374    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 375    let position = snapshot.anchor_before(language::Point::new(1, 3));
 376
 377    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 378        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 379    });
 380
 381    let (request, respond_tx) = requests.predict.next().await.unwrap();
 382
 383    // TODO Put back when we have a structured request again
 384    // assert_eq!(
 385    //     request.excerpt_path.as_ref(),
 386    //     Path::new(path!("root/foo.md"))
 387    // );
 388    // assert_eq!(
 389    //     request.cursor_point,
 390    //     Point {
 391    //         line: Line(1),
 392    //         column: 3
 393    //     }
 394    // );
 395
 396    respond_tx
 397        .send(model_response(
 398            &request,
 399            indoc! { r"
 400                --- a/root/foo.md
 401                +++ b/root/foo.md
 402                @@ ... @@
 403                 Hello!
 404                -How
 405                +How are you?
 406                 Bye
 407            "},
 408        ))
 409        .unwrap();
 410
 411    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 412
 413    assert_eq!(prediction.edits.len(), 1);
 414    assert_eq!(
 415        prediction.edits[0].0.to_point(&snapshot).start,
 416        language::Point::new(1, 3)
 417    );
 418    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 419}
 420
 421#[gpui::test]
 422async fn test_request_events(cx: &mut TestAppContext) {
 423    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 424    let fs = FakeFs::new(cx.executor());
 425    fs.insert_tree(
 426        "/root",
 427        json!({
 428            "foo.md": "Hello!\n\nBye\n"
 429        }),
 430    )
 431    .await;
 432    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 433
 434    let buffer = project
 435        .update(cx, |project, cx| {
 436            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 437            project.open_buffer(path, cx)
 438        })
 439        .await
 440        .unwrap();
 441
 442    ep_store.update(cx, |ep_store, cx| {
 443        ep_store.register_buffer(&buffer, &project, cx);
 444    });
 445
 446    buffer.update(cx, |buffer, cx| {
 447        buffer.edit(vec![(7..7, "How")], None, cx);
 448    });
 449
 450    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 451    let position = snapshot.anchor_before(language::Point::new(1, 3));
 452
 453    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 454        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 455    });
 456
 457    let (request, respond_tx) = requests.predict.next().await.unwrap();
 458
 459    let prompt = prompt_from_request(&request);
 460    assert!(
 461        prompt.contains(indoc! {"
 462        --- a/root/foo.md
 463        +++ b/root/foo.md
 464        @@ -1,3 +1,3 @@
 465         Hello!
 466        -
 467        +How
 468         Bye
 469    "}),
 470        "{prompt}"
 471    );
 472
 473    respond_tx
 474        .send(model_response(
 475            &request,
 476            indoc! {r#"
 477                --- a/root/foo.md
 478                +++ b/root/foo.md
 479                @@ ... @@
 480                 Hello!
 481                -How
 482                +How are you?
 483                 Bye
 484        "#},
 485        ))
 486        .unwrap();
 487
 488    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 489
 490    assert_eq!(prediction.edits.len(), 1);
 491    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 492}
 493
/// Checks that `edit_history_for_project` returns two events for edits on the same line
/// when they are separated by a pause longer than `LAST_CHANGE_GROUPING_TIME`. It also checks
/// each event's total edit range and diff.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How" at offset 7, the start of the empty second line.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause of twice the grouping threshold (`LAST_CHANGE_GROUPING_TIME`).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 2);

    // Event 1 covers only the pre-pause "How" (row 1, columns 0..3).
    let first_total_edit_range = buffer.read_with(cx, |buffer, _| {
        events[0].total_edit_range.to_point(&buffer.snapshot())
    });
    assert_eq!(first_total_edit_range, Point::new(1, 0)..Point::new(1, 3));

    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Event 2 covers the post-pause " are you?!" (row 1, columns 3..13). Its diff is taken
    // relative to the text as of the end of event 1.
    let second_total_edit_range = buffer.read_with(cx, |buffer, _| {
        events[1].total_edit_range.to_point(&buffer.snapshot())
    });
    assert_eq!(second_total_edit_range, Point::new(1, 3)..Point::new(1, 13));

    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 582
/// Checks line-proximity coalescing in the edit history. Edits a few lines above or below
/// the current event's range are merged into that event; an edit far below starts a new,
/// separate event.
///
/// NOTE(review): despite its name, this test only performs manual buffer edits — no
/// predicted edit is applied. Consider renaming it to describe the coalescing it covers.
#[gpui::test]
async fn test_predicted_edits_are_separated_in_edit_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines ("Line 1" through "Line 30") to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    // (row 14 now holds "Line 15", so "Below" lands between "Line 14" and "Line 15").
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    // (row 25 now holds "Line 25", so "Far below" lands just before it).
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // `render_events` joins diffs with "\n---\n". The first diff already ends in a newline,
    // which produces the blank line before the separator.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 761
 762fn render_events(events: &[StoredEvent]) -> String {
 763    events
 764        .iter()
 765        .map(|e| {
 766            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 767            diff.as_str()
 768        })
 769        .collect::<Vec<_>>()
 770        .join("\n---\n")
 771}
 772
 773fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
 774    events
 775        .iter()
 776        .map(|e| {
 777            let zeta_prompt::Event::BufferChange {
 778                diff, predicted, ..
 779            } = e.event.as_ref();
 780            let prefix = if *predicted { "predicted" } else { "manual" };
 781            format!("{}\n{}", prefix, diff)
 782        })
 783        .collect()
 784}
 785
 786fn make_collaborator_replica(
 787    buffer: &Entity<Buffer>,
 788    cx: &mut TestAppContext,
 789) -> (Entity<Buffer>, clock::Global) {
 790    let (state, version) =
 791        buffer.read_with(cx, |buffer, _cx| (buffer.to_proto(_cx), buffer.version()));
 792    let collaborator = cx.new(|_cx| {
 793        Buffer::from_proto(ReplicaId::new(1), Capability::ReadWrite, state, None).unwrap()
 794    });
 795    (collaborator, version)
 796}
 797
 798async fn apply_collaborator_edit(
 799    collaborator: &Entity<Buffer>,
 800    buffer: &Entity<Buffer>,
 801    since_version: &mut clock::Global,
 802    edit_range: Range<usize>,
 803    new_text: &str,
 804    cx: &mut TestAppContext,
 805) {
 806    collaborator.update(cx, |collaborator, cx| {
 807        collaborator.edit([(edit_range, new_text)], None, cx);
 808    });
 809
 810    let serialize_task = collaborator.read_with(cx, |collaborator, cx| {
 811        collaborator.serialize_ops(Some(since_version.clone()), cx)
 812    });
 813    let ops = serialize_task.await;
 814    *since_version = collaborator.read_with(cx, |collaborator, _cx| collaborator.version());
 815
 816    buffer.update(cx, |buffer, cx| {
 817        buffer.apply_ops(
 818            ops.into_iter()
 819                .map(|op| language::proto::deserialize_operation(op).unwrap()),
 820            cx,
 821        );
 822    });
 823}
 824
/// Checks that a collaborator's edit to the line right below a local edit is kept in the
/// edit history. It is rendered in the same `manual` event as the local edit.
#[gpui::test]
async fn test_nearby_collaborator_edits_are_kept_in_history(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));

    // Register the buffer, then query `prediction_at` with a cursor at row 1 and discard the
    // result. NOTE(review): presumably this records the cursor position that the store
    // uses to judge which collaborator edits are nearby — confirm.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
        let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
    });

    // Local edit: replace "line 0" (offsets 0..6) with "LOCAL ZERO".
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
    });

    // Replica created after the local edit, so it already contains "LOCAL ZERO".
    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);

    // Remote edit: replace the whole of row 1 ("line 1") with "REMOTE ONE".
    let (line_one_start, line_one_len) = collaborator.read_with(cx, |buffer, _cx| {
        (Point::new(1, 0).to_offset(buffer), buffer.line_len(1))
    });

    apply_collaborator_edit(
        &collaborator,
        &buffer,
        &mut collaborator_version,
        line_one_start..line_one_start + line_one_len as usize,
        "REMOTE ONE",
        cx,
    )
    .await;

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    // A single `manual` event contains both the local and the remote change.
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LOCAL ZERO
            +REMOTE ONE
             line 2
             line 3
             line 4
        "}]
    );
}
 893
 894#[gpui::test]
 895async fn test_distant_collaborator_edits_are_omitted_from_history(cx: &mut TestAppContext) {
 896    let (ep_store, _requests) = init_test_with_fake_client(cx);
 897    let fs = FakeFs::new(cx.executor());
 898    fs.insert_tree(
 899        "/root",
 900        json!({
 901            "foo.rs": (0..1000)
 902                .map(|i| format!("line {i}\n"))
 903                .collect::<String>()
 904        }),
 905    )
 906    .await;
 907    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 908
 909    let buffer = project
 910        .update(cx, |project, cx| {
 911            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
 912            project.set_active_path(Some(path.clone()), cx);
 913            project.open_buffer(path, cx)
 914        })
 915        .await
 916        .unwrap();
 917
 918    let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));
 919
 920    ep_store.update(cx, |ep_store, cx| {
 921        ep_store.register_buffer(&buffer, &project, cx);
 922        let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
 923    });
 924
 925    buffer.update(cx, |buffer, cx| {
 926        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
 927    });
 928
 929    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
 930
 931    let far_line_start = buffer.read_with(cx, |buffer, _cx| Point::new(900, 0).to_offset(buffer));
 932
 933    apply_collaborator_edit(
 934        &collaborator,
 935        &buffer,
 936        &mut collaborator_version,
 937        far_line_start..far_line_start + 7,
 938        "REMOTE FAR",
 939        cx,
 940    )
 941    .await;
 942
 943    let events = ep_store.update(cx, |ep_store, cx| {
 944        ep_store.edit_history_for_project(&project, cx)
 945    });
 946
 947    assert_eq!(
 948        render_events_with_predicted(&events),
 949        vec![indoc! {"
 950            manual
 951            @@ -1,4 +1,4 @@
 952            -line 0
 953            +LOCAL ZERO
 954             line 1
 955             line 2
 956             line 3
 957        "}]
 958    );
 959}
 960
 961#[gpui::test]
 962async fn test_irrelevant_collaborator_edits_in_different_files_are_omitted_from_history(
 963    cx: &mut TestAppContext,
 964) {
 965    let (ep_store, _requests) = init_test_with_fake_client(cx);
 966    let fs = FakeFs::new(cx.executor());
 967    fs.insert_tree(
 968        "/root",
 969        json!({
 970            "foo.rs": "line 0\nline 1\nline 2\nline 3\n",
 971            "bar.rs": "line 0\nline 1\nline 2\nline 3\n"
 972        }),
 973    )
 974    .await;
 975    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 976
 977    let foo_buffer = project
 978        .update(cx, |project, cx| {
 979            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
 980            project.set_active_path(Some(path.clone()), cx);
 981            project.open_buffer(path, cx)
 982        })
 983        .await
 984        .unwrap();
 985    let bar_buffer = project
 986        .update(cx, |project, cx| {
 987            let path = project.find_project_path(path!("root/bar.rs"), cx).unwrap();
 988            project.open_buffer(path, cx)
 989        })
 990        .await
 991        .unwrap();
 992
 993    let foo_cursor = foo_buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));
 994
 995    ep_store.update(cx, |ep_store, cx| {
 996        ep_store.register_buffer(&foo_buffer, &project, cx);
 997        ep_store.register_buffer(&bar_buffer, &project, cx);
 998        let _ = ep_store.prediction_at(&foo_buffer, Some(foo_cursor), &project, cx);
 999    });
1000
1001    let (bar_collaborator, mut bar_version) = make_collaborator_replica(&bar_buffer, cx);
1002
1003    apply_collaborator_edit(
1004        &bar_collaborator,
1005        &bar_buffer,
1006        &mut bar_version,
1007        0..6,
1008        "REMOTE BAR",
1009        cx,
1010    )
1011    .await;
1012
1013    let events = ep_store.update(cx, |ep_store, cx| {
1014        ep_store.edit_history_for_project(&project, cx)
1015    });
1016
1017    assert!(events.is_empty());
1018}
1019
1020#[gpui::test]
1021async fn test_large_edits_are_omitted_from_history(cx: &mut TestAppContext) {
1022    let (ep_store, _requests) = init_test_with_fake_client(cx);
1023    let fs = FakeFs::new(cx.executor());
1024    fs.insert_tree(
1025        "/root",
1026        json!({
1027            "foo.rs": (0..20)
1028                .map(|i| format!("line {i}\n"))
1029                .collect::<String>()
1030        }),
1031    )
1032    .await;
1033    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1034
1035    let buffer = project
1036        .update(cx, |project, cx| {
1037            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
1038            project.set_active_path(Some(path.clone()), cx);
1039            project.open_buffer(path, cx)
1040        })
1041        .await
1042        .unwrap();
1043
1044    let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));
1045
1046    ep_store.update(cx, |ep_store, cx| {
1047        ep_store.register_buffer(&buffer, &project, cx);
1048        let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
1049    });
1050
1051    buffer.update(cx, |buffer, cx| {
1052        buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
1053    });
1054
1055    let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
1056
1057    let (line_three_start, line_three_len) = collaborator.read_with(cx, |buffer, _cx| {
1058        (Point::new(3, 0).to_offset(buffer), buffer.line_len(3))
1059    });
1060    let large_edit = "X".repeat(EDIT_HISTORY_DIFF_SIZE_LIMIT + 1);
1061
1062    apply_collaborator_edit(
1063        &collaborator,
1064        &buffer,
1065        &mut collaborator_version,
1066        line_three_start..line_three_start + line_three_len as usize,
1067        &large_edit,
1068        cx,
1069    )
1070    .await;
1071
1072    buffer.update(cx, |buffer, cx| {
1073        let line_seven_start = Point::new(7, 0).to_offset(buffer);
1074        let line_seven_end = Point::new(7, 6).to_offset(buffer);
1075        buffer.edit(
1076            vec![(line_seven_start..line_seven_end, "LOCAL SEVEN")],
1077            None,
1078            cx,
1079        );
1080    });
1081
1082    let events = ep_store.update(cx, |ep_store, cx| {
1083        ep_store.edit_history_for_project(&project, cx)
1084    });
1085
1086    let rendered_events = render_events_with_predicted(&events);
1087
1088    assert_eq!(rendered_events.len(), 2);
1089    assert!(rendered_events[0].contains("+LOCAL ZERO"));
1090    assert!(!rendered_events[0].contains(&large_edit));
1091    assert!(rendered_events[1].contains("+LOCAL SEVEN"));
1092    assert!(!rendered_events[1].contains(&large_edit));
1093}
1094
/// Exercises how successive buffer edits are coalesced into edit-history events
/// and how each event's `predicted` flag (rendered as `predicted` / `manual`)
/// is assigned. Nearby edits of the same kind merge. Manual and predicted edits
/// at the same location stay separate until an edit elsewhere flushes them.
/// Once flushed, a merged event is `predicted` only if all of its edits were.
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    // Fifteen lines, "line 0" through "line 14".
    fs.insert_tree(
        "/root",
        json!({
            "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Case 1: Manual edits have `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(0..6, "LINE ZERO")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });

    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,4 +1,4 @@
            -line 0
            +LINE ZERO
             line 1
             line 2
             line 3
        "}]
    );

    // Case 2: Multiple successive manual edits near each other are merged into one
    // event with `predicted` set to false.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(1, 0).to_offset(buffer);
        let end = Point::new(1, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE ONE")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![indoc! {"
            manual
            @@ -1,5 +1,5 @@
            -line 0
            -line 1
            +LINE ZERO
            +LINE ONE
             line 2
             line 3
             line 4
        "}]
    );

    // Case 3: Accepted predictions have `predicted` set to true. A predicted
    // edit that follows manual edits starts its own event instead of merging
    // into the manual one, even though it is adjacent.
    // The edit is reported via `report_changes_for_buffer` with both boolean
    // flags set; the `predicted` event asserted below shows this marks it as an
    // accepted prediction. NOTE(review): confirm which flag carries that meaning.
    // (The mirror case, a manual edit after a predicted one, is Case 5 below.)
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(2, 0).to_offset(buffer);
            let end = Point::new(2, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,6 +1,6 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                +LINE TWO
                 line 3
                 line 4
                 line 5
            "}
        ]
    );

    // Case 4: Multiple successive accepted predictions near each other are merged
    // into one event with `predicted` set to true.
    ep_store.update(cx, |ep_store, cx| {
        buffer.update(cx, |buffer, cx| {
            let offset = Point::new(3, 0).to_offset(buffer);
            let end = Point::new(3, 6).to_offset(buffer);
            buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
        });
        ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "}
        ]
    );

    // Case 5: A manual edit that follows a predicted edit is not merged with
    // the predicted edit, even if it is nearby.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(4, 0).to_offset(buffer);
        let end = Point::new(4, 6).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOUR")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,5 +1,5 @@
                -line 0
                -line 1
                +LINE ZERO
                +LINE ONE
                 line 2
                 line 3
                 line 4
            "},
            indoc! {"
                predicted
                @@ -1,7 +1,7 @@
                 LINE ZERO
                 LINE ONE
                -line 2
                -line 3
                +LINE TWO
                +LINE THREE
                 line 4
                 line 5
                 line 6
            "},
            indoc! {"
                manual
                @@ -2,7 +2,7 @@
                 LINE ONE
                 LINE TWO
                 LINE THREE
                -line 4
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "}
        ]
    );

    // Case 6: If we then perform a manual edit at a *different* location (more than
    // 8 lines away), then the edits at the prior location can be merged with each
    // other, even if some are predicted and some are not. `predicted` means all
    // constituent edits were predicted.
    // Here the merged event mixes manual and predicted edits, so it renders as
    // `manual`. Line 14 is "line 14", 7 bytes long, hence column 7.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        let end = Point::new(14, 7).to_offset(buffer);
        buffer.edit(vec![(offset..end, "LINE FOURTEEN")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events_with_predicted(&events),
        vec![
            indoc! {"
                manual
                @@ -1,8 +1,8 @@
                -line 0
                -line 1
                -line 2
                -line 3
                -line 4
                +LINE ZERO
                +LINE ONE
                +LINE TWO
                +LINE THREE
                +LINE FOUR
                 line 5
                 line 6
                 line 7
            "},
            indoc! {"
                manual
                @@ -12,4 +12,4 @@
                 line 11
                 line 12
                 line 13
                -line 14
                +LINE FOURTEEN
            "}
        ]
    );
}
1353
1354#[gpui::test]
1355async fn test_empty_prediction(cx: &mut TestAppContext) {
1356    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1357    let fs = FakeFs::new(cx.executor());
1358    fs.insert_tree(
1359        "/root",
1360        json!({
1361            "foo.md":  "Hello!\nHow\nBye\n"
1362        }),
1363    )
1364    .await;
1365    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1366
1367    let buffer = project
1368        .update(cx, |project, cx| {
1369            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1370            project.open_buffer(path, cx)
1371        })
1372        .await
1373        .unwrap();
1374    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1375    let position = snapshot.anchor_before(language::Point::new(1, 3));
1376
1377    ep_store.update(cx, |ep_store, cx| {
1378        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1379    });
1380
1381    let (request, respond_tx) = requests.predict.next().await.unwrap();
1382    let mut response = model_response(&request, "");
1383    response.model_version = Some("zeta2:test-empty".to_string());
1384    let id = response.request_id.clone();
1385    respond_tx.send(response).unwrap();
1386
1387    cx.run_until_parked();
1388
1389    ep_store.update(cx, |ep_store, cx| {
1390        assert!(
1391            ep_store
1392                .prediction_at(&buffer, None, &project, cx)
1393                .is_none()
1394        );
1395    });
1396
1397    // prediction is reported as rejected
1398    let (reject_request, _) = requests.reject.next().await.unwrap();
1399
1400    assert_eq!(
1401        &reject_request.rejections,
1402        &[EditPredictionRejection {
1403            request_id: id,
1404            reason: EditPredictionRejectReason::Empty,
1405            was_shown: false,
1406            model_version: Some("zeta2:test-empty".to_string()),
1407            e2e_latency_ms: Some(0),
1408        }]
1409    );
1410}
1411
1412#[gpui::test]
1413async fn test_interpolated_empty(cx: &mut TestAppContext) {
1414    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1415    let fs = FakeFs::new(cx.executor());
1416    fs.insert_tree(
1417        "/root",
1418        json!({
1419            "foo.md":  "Hello!\nHow\nBye\n"
1420        }),
1421    )
1422    .await;
1423    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1424
1425    let buffer = project
1426        .update(cx, |project, cx| {
1427            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1428            project.open_buffer(path, cx)
1429        })
1430        .await
1431        .unwrap();
1432    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1433    let position = snapshot.anchor_before(language::Point::new(1, 3));
1434
1435    ep_store.update(cx, |ep_store, cx| {
1436        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1437    });
1438
1439    let (request, respond_tx) = requests.predict.next().await.unwrap();
1440
1441    buffer.update(cx, |buffer, cx| {
1442        buffer.set_text("Hello!\nHow are you?\nBye", cx);
1443    });
1444
1445    let mut response = model_response(&request, SIMPLE_DIFF);
1446    response.model_version = Some("zeta2:test-interpolated-empty".to_string());
1447    let id = response.request_id.clone();
1448    respond_tx.send(response).unwrap();
1449
1450    cx.run_until_parked();
1451
1452    ep_store.update(cx, |ep_store, cx| {
1453        assert!(
1454            ep_store
1455                .prediction_at(&buffer, None, &project, cx)
1456                .is_none()
1457        );
1458    });
1459
1460    // prediction is reported as rejected
1461    let (reject_request, _) = requests.reject.next().await.unwrap();
1462
1463    assert_eq!(
1464        &reject_request.rejections,
1465        &[EditPredictionRejection {
1466            request_id: id,
1467            reason: EditPredictionRejectReason::InterpolatedEmpty,
1468            was_shown: false,
1469            model_version: Some("zeta2:test-interpolated-empty".to_string()),
1470            e2e_latency_ms: Some(0),
1471        }]
1472    );
1473}
1474
/// Unified diff handed to `model_response` as a canned model answer. It
/// rewrites the second line of `root/foo.md` from `How` to `How are you?`.
/// Each line is spelled out so its leading ` `/`-`/`+` marker stays visible.
const SIMPLE_DIFF: &str = concat!(
    "--- a/root/foo.md\n",
    "+++ b/root/foo.md\n",
    "@@ ... @@\n",
    " Hello!\n",
    "-How\n",
    "+How are you?\n",
    " Bye\n",
);
1484
1485#[gpui::test]
1486async fn test_replace_current(cx: &mut TestAppContext) {
1487    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1488    let fs = FakeFs::new(cx.executor());
1489    fs.insert_tree(
1490        "/root",
1491        json!({
1492            "foo.md":  "Hello!\nHow\nBye\n"
1493        }),
1494    )
1495    .await;
1496    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1497
1498    let buffer = project
1499        .update(cx, |project, cx| {
1500            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1501            project.open_buffer(path, cx)
1502        })
1503        .await
1504        .unwrap();
1505    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1506    let position = snapshot.anchor_before(language::Point::new(1, 3));
1507
1508    ep_store.update(cx, |ep_store, cx| {
1509        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1510    });
1511
1512    let (request, respond_tx) = requests.predict.next().await.unwrap();
1513    let first_response = model_response(&request, SIMPLE_DIFF);
1514    let first_id = first_response.request_id.clone();
1515    respond_tx.send(first_response).unwrap();
1516
1517    cx.run_until_parked();
1518
1519    ep_store.update(cx, |ep_store, cx| {
1520        assert_eq!(
1521            ep_store
1522                .prediction_at(&buffer, None, &project, cx)
1523                .unwrap()
1524                .id
1525                .0,
1526            first_id
1527        );
1528    });
1529
1530    // a second request is triggered
1531    ep_store.update(cx, |ep_store, cx| {
1532        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1533    });
1534
1535    let (request, respond_tx) = requests.predict.next().await.unwrap();
1536    let second_response = model_response(&request, SIMPLE_DIFF);
1537    let second_id = second_response.request_id.clone();
1538    respond_tx.send(second_response).unwrap();
1539
1540    cx.run_until_parked();
1541
1542    ep_store.update(cx, |ep_store, cx| {
1543        // second replaces first
1544        assert_eq!(
1545            ep_store
1546                .prediction_at(&buffer, None, &project, cx)
1547                .unwrap()
1548                .id
1549                .0,
1550            second_id
1551        );
1552    });
1553
1554    // first is reported as replaced
1555    let (reject_request, _) = requests.reject.next().await.unwrap();
1556
1557    assert_eq!(
1558        &reject_request.rejections,
1559        &[EditPredictionRejection {
1560            request_id: first_id,
1561            reason: EditPredictionRejectReason::Replaced,
1562            was_shown: false,
1563            model_version: None,
1564            e2e_latency_ms: Some(0),
1565        }]
1566    );
1567}
1568
1569#[gpui::test]
1570async fn test_current_preferred(cx: &mut TestAppContext) {
1571    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1572    let fs = FakeFs::new(cx.executor());
1573    fs.insert_tree(
1574        "/root",
1575        json!({
1576            "foo.md":  "Hello!\nHow\nBye\n"
1577        }),
1578    )
1579    .await;
1580    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1581
1582    let buffer = project
1583        .update(cx, |project, cx| {
1584            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1585            project.open_buffer(path, cx)
1586        })
1587        .await
1588        .unwrap();
1589    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1590    let position = snapshot.anchor_before(language::Point::new(1, 3));
1591
1592    ep_store.update(cx, |ep_store, cx| {
1593        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1594    });
1595
1596    let (request, respond_tx) = requests.predict.next().await.unwrap();
1597    let first_response = model_response(&request, SIMPLE_DIFF);
1598    let first_id = first_response.request_id.clone();
1599    respond_tx.send(first_response).unwrap();
1600
1601    cx.run_until_parked();
1602
1603    ep_store.update(cx, |ep_store, cx| {
1604        assert_eq!(
1605            ep_store
1606                .prediction_at(&buffer, None, &project, cx)
1607                .unwrap()
1608                .id
1609                .0,
1610            first_id
1611        );
1612    });
1613
1614    // a second request is triggered
1615    ep_store.update(cx, |ep_store, cx| {
1616        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1617    });
1618
1619    let (request, respond_tx) = requests.predict.next().await.unwrap();
1620    // worse than current prediction
1621    let mut second_response = model_response(
1622        &request,
1623        indoc! { r"
1624            --- a/root/foo.md
1625            +++ b/root/foo.md
1626            @@ ... @@
1627             Hello!
1628            -How
1629            +How are
1630             Bye
1631        "},
1632    );
1633    second_response.model_version = Some("zeta2:test-current-preferred".to_string());
1634    let second_id = second_response.request_id.clone();
1635    respond_tx.send(second_response).unwrap();
1636
1637    cx.run_until_parked();
1638
1639    ep_store.update(cx, |ep_store, cx| {
1640        // first is preferred over second
1641        assert_eq!(
1642            ep_store
1643                .prediction_at(&buffer, None, &project, cx)
1644                .unwrap()
1645                .id
1646                .0,
1647            first_id
1648        );
1649    });
1650
1651    // second is reported as rejected
1652    let (reject_request, _) = requests.reject.next().await.unwrap();
1653
1654    assert_eq!(
1655        &reject_request.rejections,
1656        &[EditPredictionRejection {
1657            request_id: second_id,
1658            reason: EditPredictionRejectReason::CurrentPreferred,
1659            was_shown: false,
1660            model_version: Some("zeta2:test-current-preferred".to_string()),
1661            e2e_latency_ms: Some(0),
1662        }]
1663    );
1664}
1665
1666#[gpui::test]
1667async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
1668    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1669    let fs = FakeFs::new(cx.executor());
1670    fs.insert_tree(
1671        "/root",
1672        json!({
1673            "foo.md":  "Hello!\nHow\nBye\n"
1674        }),
1675    )
1676    .await;
1677    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1678
1679    let buffer = project
1680        .update(cx, |project, cx| {
1681            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1682            project.open_buffer(path, cx)
1683        })
1684        .await
1685        .unwrap();
1686    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1687    let position = snapshot.anchor_before(language::Point::new(1, 3));
1688
1689    // start two refresh tasks
1690    ep_store.update(cx, |ep_store, cx| {
1691        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1692    });
1693
1694    let (request1, respond_first) = requests.predict.next().await.unwrap();
1695
1696    ep_store.update(cx, |ep_store, cx| {
1697        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1698    });
1699
1700    let (request, respond_second) = requests.predict.next().await.unwrap();
1701
1702    // wait for throttle
1703    cx.run_until_parked();
1704
1705    // second responds first
1706    let second_response = model_response(&request, SIMPLE_DIFF);
1707    let second_id = second_response.request_id.clone();
1708    respond_second.send(second_response).unwrap();
1709
1710    cx.run_until_parked();
1711
1712    ep_store.update(cx, |ep_store, cx| {
1713        // current prediction is second
1714        assert_eq!(
1715            ep_store
1716                .prediction_at(&buffer, None, &project, cx)
1717                .unwrap()
1718                .id
1719                .0,
1720            second_id
1721        );
1722    });
1723
1724    let mut first_response = model_response(&request1, SIMPLE_DIFF);
1725    first_response.model_version = Some("zeta2:test-canceled".to_string());
1726    let first_id = first_response.request_id.clone();
1727    respond_first.send(first_response).unwrap();
1728
1729    cx.run_until_parked();
1730
1731    ep_store.update(cx, |ep_store, cx| {
1732        // current prediction is still second, since first was cancelled
1733        assert_eq!(
1734            ep_store
1735                .prediction_at(&buffer, None, &project, cx)
1736                .unwrap()
1737                .id
1738                .0,
1739            second_id
1740        );
1741    });
1742
1743    // first is reported as rejected
1744    let (reject_request, _) = requests.reject.next().await.unwrap();
1745
1746    cx.run_until_parked();
1747
1748    assert_eq!(
1749        &reject_request.rejections,
1750        &[EditPredictionRejection {
1751            request_id: first_id,
1752            reason: EditPredictionRejectReason::Canceled,
1753            was_shown: false,
1754            model_version: Some("zeta2:test-canceled".to_string()),
1755            e2e_latency_ms: None,
1756        }]
1757    );
1758}
1759
1760#[gpui::test]
1761async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
1762    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1763    let fs = FakeFs::new(cx.executor());
1764    fs.insert_tree(
1765        "/root",
1766        json!({
1767            "foo.md":  "Hello!\nHow\nBye\n"
1768        }),
1769    )
1770    .await;
1771    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1772
1773    let buffer = project
1774        .update(cx, |project, cx| {
1775            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1776            project.open_buffer(path, cx)
1777        })
1778        .await
1779        .unwrap();
1780    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1781    let position = snapshot.anchor_before(language::Point::new(1, 3));
1782
1783    // start two refresh tasks
1784    ep_store.update(cx, |ep_store, cx| {
1785        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1786    });
1787
1788    let (request1, respond_first) = requests.predict.next().await.unwrap();
1789
1790    ep_store.update(cx, |ep_store, cx| {
1791        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1792    });
1793
1794    let (request2, respond_second) = requests.predict.next().await.unwrap();
1795
1796    // wait for throttle, so requests are sent
1797    cx.run_until_parked();
1798
1799    ep_store.update(cx, |ep_store, cx| {
1800        // start a third request
1801        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1802
1803        // 2 are pending, so 2nd is cancelled
1804        assert_eq!(
1805            ep_store
1806                .get_or_init_project(&project, cx)
1807                .cancelled_predictions
1808                .iter()
1809                .copied()
1810                .collect::<Vec<_>>(),
1811            [1]
1812        );
1813    });
1814
1815    // wait for throttle
1816    cx.run_until_parked();
1817
1818    let (request3, respond_third) = requests.predict.next().await.unwrap();
1819
1820    let first_response = model_response(&request1, SIMPLE_DIFF);
1821    let first_id = first_response.request_id.clone();
1822    respond_first.send(first_response).unwrap();
1823
1824    cx.run_until_parked();
1825
1826    ep_store.update(cx, |ep_store, cx| {
1827        // current prediction is first
1828        assert_eq!(
1829            ep_store
1830                .prediction_at(&buffer, None, &project, cx)
1831                .unwrap()
1832                .id
1833                .0,
1834            first_id
1835        );
1836    });
1837
1838    let mut cancelled_response = model_response(&request2, SIMPLE_DIFF);
1839    cancelled_response.model_version = Some("zeta2:test-canceled-second".to_string());
1840    let cancelled_id = cancelled_response.request_id.clone();
1841    respond_second.send(cancelled_response).unwrap();
1842
1843    cx.run_until_parked();
1844
1845    ep_store.update(cx, |ep_store, cx| {
1846        // current prediction is still first, since second was cancelled
1847        assert_eq!(
1848            ep_store
1849                .prediction_at(&buffer, None, &project, cx)
1850                .unwrap()
1851                .id
1852                .0,
1853            first_id
1854        );
1855    });
1856
1857    let third_response = model_response(&request3, SIMPLE_DIFF);
1858    let third_response_id = third_response.request_id.clone();
1859    respond_third.send(third_response).unwrap();
1860
1861    cx.run_until_parked();
1862
1863    ep_store.update(cx, |ep_store, cx| {
1864        // third completes and replaces first
1865        assert_eq!(
1866            ep_store
1867                .prediction_at(&buffer, None, &project, cx)
1868                .unwrap()
1869                .id
1870                .0,
1871            third_response_id
1872        );
1873    });
1874
1875    // second is reported as rejected
1876    let (reject_request, _) = requests.reject.next().await.unwrap();
1877
1878    cx.run_until_parked();
1879
1880    assert_eq!(
1881        &reject_request.rejections,
1882        &[
1883            EditPredictionRejection {
1884                request_id: cancelled_id,
1885                reason: EditPredictionRejectReason::Canceled,
1886                was_shown: false,
1887                model_version: Some("zeta2:test-canceled-second".to_string()),
1888                e2e_latency_ms: None,
1889            },
1890            EditPredictionRejection {
1891                request_id: first_id,
1892                reason: EditPredictionRejectReason::Replaced,
1893                was_shown: false,
1894                model_version: None,
1895                // 2 throttle waits (for 2nd and 3rd requests) elapsed
1896                // between this request's start and response.
1897                e2e_latency_ms: Some(2 * EditPredictionStore::THROTTLE_TIMEOUT.as_millis()),
1898            }
1899        ]
1900    );
1901}
1902
// Verifies that edit-triggered (buffer refresh) and jump-triggered (diagnostics)
// prediction requests are throttled independently: a recent edit request does not
// delay the first jump request, while a second request of either kind is held back
// until `EditPredictionStore::THROTTLE_TIMEOUT` has elapsed.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of the second line ("How") of foo.md.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Only foo.md's buffer is registered; the diagnostic published below targets bar.md.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
2003
// Verifies that two identical refresh calls issued in the same synchronous frame
// result in exactly one prediction request being sent.
#[gpui::test]
async fn test_same_frame_duplicate_requests_deduplicated(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of the second line ("How").
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Enqueue two refresh calls in the same synchronous frame (no yielding).
    // Both `cx.spawn` tasks are created before either executes, so they both
    // capture the same `proceed_count_at_enqueue`. Only the first task should
    // pass the deduplication gate; the second should be skipped.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Let both spawned tasks run to completion (including any throttle waits).
    cx.run_until_parked();

    // Exactly one prediction request should have been sent.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(&request, SIMPLE_DIFF))
        .unwrap();
    cx.run_until_parked();

    // No second request should be pending.
    assert_no_predict_request_ready(&mut requests.predict);
}
2049
2050#[gpui::test]
2051async fn test_rejections_flushing(cx: &mut TestAppContext) {
2052    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2053
2054    ep_store.update(cx, |ep_store, cx| {
2055        ep_store.reject_prediction(
2056            EditPredictionId("test-1".into()),
2057            EditPredictionRejectReason::Discarded,
2058            false,
2059            None,
2060            None,
2061            cx,
2062        );
2063        ep_store.reject_prediction(
2064            EditPredictionId("test-2".into()),
2065            EditPredictionRejectReason::Canceled,
2066            true,
2067            None,
2068            None,
2069            cx,
2070        );
2071    });
2072
2073    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
2074    cx.run_until_parked();
2075
2076    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2077    respond_tx.send(()).unwrap();
2078
2079    // batched
2080    assert_eq!(reject_request.rejections.len(), 2);
2081    assert_eq!(
2082        reject_request.rejections[0],
2083        EditPredictionRejection {
2084            request_id: "test-1".to_string(),
2085            reason: EditPredictionRejectReason::Discarded,
2086            was_shown: false,
2087            model_version: None,
2088            e2e_latency_ms: None
2089        }
2090    );
2091    assert_eq!(
2092        reject_request.rejections[1],
2093        EditPredictionRejection {
2094            request_id: "test-2".to_string(),
2095            reason: EditPredictionRejectReason::Canceled,
2096            was_shown: true,
2097            model_version: None,
2098            e2e_latency_ms: None
2099        }
2100    );
2101
2102    // Reaching batch size limit sends without debounce
2103    ep_store.update(cx, |ep_store, cx| {
2104        for i in 0..70 {
2105            ep_store.reject_prediction(
2106                EditPredictionId(format!("batch-{}", i).into()),
2107                EditPredictionRejectReason::Discarded,
2108                false,
2109                None,
2110                None,
2111                cx,
2112            );
2113        }
2114    });
2115
2116    // First MAX/2 items are sent immediately
2117    cx.run_until_parked();
2118    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2119    respond_tx.send(()).unwrap();
2120
2121    assert_eq!(reject_request.rejections.len(), 50);
2122    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
2123    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
2124
2125    // Remaining items are debounced with the next batch
2126    cx.executor().advance_clock(Duration::from_secs(15));
2127    cx.run_until_parked();
2128
2129    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2130    respond_tx.send(()).unwrap();
2131
2132    assert_eq!(reject_request.rejections.len(), 20);
2133    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
2134    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
2135
2136    // Request failure
2137    ep_store.update(cx, |ep_store, cx| {
2138        ep_store.reject_prediction(
2139            EditPredictionId("retry-1".into()),
2140            EditPredictionRejectReason::Discarded,
2141            false,
2142            None,
2143            None,
2144            cx,
2145        );
2146    });
2147
2148    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
2149    cx.run_until_parked();
2150
2151    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
2152    assert_eq!(reject_request.rejections.len(), 1);
2153    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
2154    // Simulate failure
2155    drop(_respond_tx);
2156
2157    // Add another rejection
2158    ep_store.update(cx, |ep_store, cx| {
2159        ep_store.reject_prediction(
2160            EditPredictionId("retry-2".into()),
2161            EditPredictionRejectReason::Discarded,
2162            false,
2163            None,
2164            None,
2165            cx,
2166        );
2167    });
2168
2169    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
2170    cx.run_until_parked();
2171
2172    // Retry should include both the failed item and the new one
2173    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
2174    respond_tx.send(()).unwrap();
2175
2176    assert_eq!(reject_request.rejections.len(), 2);
2177    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
2178    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
2179}
2180
// Exercises `zeta::active_buffer_diagnostics`, checking which diagnostics are
// reported for a given search range and the snippet, row range, and in-snippet
// byte range reported for each.
//
// The assertions show severities reported as LSP numeric values
// (ERROR -> 1, WARNING -> 2, INFORMATION -> 3).
#[gpui::test]
fn test_active_buffer_diagnostics_fetching(cx: &mut TestAppContext) {
    // `«…»` marks the ranges that receive diagnostics; `[…]` marks the search range.
    let diagnostic_marker: TextRangeMarker = ('«', '»').into();
    let search_range_marker: TextRangeMarker = ('[', ']').into();

    let (text, mut ranges) = marked_text_ranges_by(
        indoc! {r#"
            fn alpha() {
                let «first_value» = 1;
            }

            [fn beta() {
                let «second_value» = 2;
                let third_value = second_value + missing_symbol;
            }ˇ]

            fn gamma() {
                let «fourth_value» = missing_other_symbol;
            }
        "#},
        vec![diagnostic_marker.clone(), search_range_marker.clone()],
    );

    let diagnostic_ranges = ranges.remove(&diagnostic_marker).unwrap_or_default();
    let search_ranges = ranges.remove(&search_range_marker).unwrap_or_default();

    let buffer = cx.new(|cx| Buffer::local(&text, cx));

    // Attach one primary diagnostic per marked range: a warning, an error, then a
    // hint, each in its own group.
    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let diagnostics = DiagnosticSet::new(
            diagnostic_ranges
                .iter()
                .enumerate()
                .map(|(index, range)| DiagnosticEntry {
                    range: snapshot.offset_to_point_utf16(range.start)
                        ..snapshot.offset_to_point_utf16(range.end),
                    diagnostic: Diagnostic {
                        severity: match index {
                            0 => DiagnosticSeverity::WARNING,
                            1 => DiagnosticSeverity::ERROR,
                            _ => DiagnosticSeverity::HINT,
                        },
                        message: match index {
                            0 => "first warning".to_string(),
                            1 => "second error".to_string(),
                            _ => "third hint".to_string(),
                        },
                        group_id: index + 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                }),
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let search_range = snapshot.offset_to_point(search_ranges[0].start)
        ..snapshot.offset_to_point(search_ranges[0].end);

    // NOTE(review): `100` looks like a size/count budget for the result — confirm
    // against `zeta::active_buffer_diagnostics`.
    let active_buffer_diagnostics = zeta::active_buffer_diagnostics(&snapshot, search_range, 100);

    // Only the diagnostic inside the `[…]` search range (`second_value`) is reported.
    assert_eq!(
        active_buffer_diagnostics,
        vec![zeta_prompt::ActiveBufferDiagnostic {
            severity: Some(1),
            message: "second error".to_string(),
            snippet: text,
            snippet_buffer_row_range: 5..5,
            diagnostic_range_in_snippet: 61..73,
        }]
    );

    // Second scenario: diagnostics on rows 0, 2 and 4 of a five-line buffer.
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                one
                two
                three
                four
                five
            "},
            cx,
        )
    });

    buffer.update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot();
        let diagnostics = DiagnosticSet::new(
            vec![
                DiagnosticEntry {
                    range: text::PointUtf16::new(0, 0)..text::PointUtf16::new(0, 3),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::ERROR,
                        message: "row zero".to_string(),
                        group_id: 1,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(2, 0)..text::PointUtf16::new(2, 5),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::WARNING,
                        message: "row two".to_string(),
                        group_id: 2,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
                DiagnosticEntry {
                    range: text::PointUtf16::new(4, 0)..text::PointUtf16::new(4, 4),
                    diagnostic: Diagnostic {
                        severity: DiagnosticSeverity::INFORMATION,
                        message: "row four".to_string(),
                        group_id: 3,
                        is_primary: true,
                        source_kind: language::DiagnosticSourceKind::Pushed,
                        ..Diagnostic::default()
                    },
                },
            ],
            &snapshot,
        );
        buffer.update_diagnostics(LanguageServerId(0), diagnostics, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

    let active_buffer_diagnostics =
        zeta::active_buffer_diagnostics(&snapshot, Point::new(2, 0)..Point::new(4, 0), 100);

    // The range 2:0..4:0 yields the row-two and row-four diagnostics (the one
    // starting exactly at 4:0 is included) but not the row-zero one.
    assert_eq!(
        active_buffer_diagnostics
            .iter()
            .map(|diagnostic| (
                diagnostic.severity,
                diagnostic.message.clone(),
                diagnostic.snippet.clone(),
                diagnostic.snippet_buffer_row_range.clone(),
                diagnostic.diagnostic_range_in_snippet.clone(),
            ))
            .collect::<Vec<_>>(),
        vec![
            (
                Some(2),
                "row two".to_string(),
                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                2..2,
                8..13,
            ),
            (
                Some(3),
                "row four".to_string(),
                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                4..4,
                19..23,
            ),
        ]
    );
}
2347
2348// Generate a model response that would apply the given diff to the active file.
2349fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
2350    let editable_range =
2351        zeta_prompt::excerpt_range_for_format(Default::default(), &request.input.excerpt_ranges).1;
2352    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
2353    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
2354
2355    PredictEditsV3Response {
2356        request_id: Uuid::new_v4().to_string(),
2357        editable_range,
2358        output: new_excerpt,
2359        model_version: None,
2360    }
2361}
2362
2363fn empty_response() -> PredictEditsV3Response {
2364    PredictEditsV3Response {
2365        request_id: Uuid::new_v4().to_string(),
2366        editable_range: 0..0,
2367        output: String::new(),
2368        model_version: None,
2369    }
2370}
2371
2372fn prompt_from_request(request: &PredictEditsV3Request) -> String {
2373    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
2374        .expect("default zeta prompt formatting should succeed in edit prediction tests")
2375}
2376
2377fn assert_no_predict_request_ready(
2378    requests: &mut mpsc::UnboundedReceiver<(
2379        PredictEditsV3Request,
2380        oneshot::Sender<PredictEditsV3Response>,
2381    )>,
2382) {
2383    if requests.next().now_or_never().flatten().is_some() {
2384        panic!("Unexpected prediction request while throttled.");
2385    }
2386}
2387
/// Receivers for the requests intercepted by the fake HTTP client built in
/// `init_test_with_fake_client_and_legacy_data_collection`.
struct RequestChannels {
    /// Decoded `/predict_edits/v3` requests, each paired with the sender the
    /// test uses to supply that request's response.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Decoded `/predict_edits/reject` batches, each paired with the sender the
    /// test uses to acknowledge that batch.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
2395
2396fn init_test_with_fake_client(
2397    cx: &mut TestAppContext,
2398) -> (Entity<EditPredictionStore>, RequestChannels) {
2399    init_test_with_fake_client_and_legacy_data_collection(cx, None)
2400}
2401
2402fn init_test_with_fake_client_and_legacy_data_collection(
2403    cx: &mut TestAppContext,
2404    legacy_data_collection_choice: Option<&str>,
2405) -> (Entity<EditPredictionStore>, RequestChannels) {
2406    cx.update(move |cx| {
2407        cx.set_global(AppDatabase::test_new());
2408        let settings_store = SettingsStore::test(cx);
2409        cx.set_global(settings_store);
2410        zlog::init_test();
2411
2412        if let Some(legacy_data_collection_choice) = legacy_data_collection_choice {
2413            KeyValueStore::global(cx)
2414                .write_kvp(
2415                    ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2416                    legacy_data_collection_choice.to_string(),
2417                )
2418                .now_or_never()
2419                .expect("legacy data collection write should complete immediately")
2420                .expect("legacy data collection write should succeed");
2421        }
2422
2423        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
2424        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
2425
2426        let http_client = FakeHttpClient::create({
2427            move |req| {
2428                let uri = req.uri().path().to_string();
2429                let mut body = req.into_body();
2430                let predict_req_tx = predict_req_tx.clone();
2431                let reject_req_tx = reject_req_tx.clone();
2432                async move {
2433                    let resp = match uri.as_str() {
2434                        "/client/llm_tokens" => serde_json::to_string(&json!({
2435                            "token": "test"
2436                        }))
2437                        .unwrap(),
2438                        "/predict_edits/v3" => {
2439                            let mut buf = Vec::new();
2440                            body.read_to_end(&mut buf).await.ok();
2441                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
2442                            let req = serde_json::from_slice(&decompressed).unwrap();
2443
2444                            let (res_tx, res_rx) = oneshot::channel();
2445                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
2446                            serde_json::to_string(&res_rx.await?).unwrap()
2447                        }
2448                        "/predict_edits/reject" => {
2449                            let mut buf = Vec::new();
2450                            body.read_to_end(&mut buf).await.ok();
2451                            let req = serde_json::from_slice(&buf).unwrap();
2452
2453                            let (res_tx, res_rx) = oneshot::channel();
2454                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
2455                            serde_json::to_string(&res_rx.await?).unwrap()
2456                        }
2457                        _ => {
2458                            panic!("Unexpected path: {}", uri)
2459                        }
2460                    };
2461
2462                    Ok(Response::builder().body(resp.into()).unwrap())
2463                }
2464            }
2465        });
2466
2467        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
2468        client.cloud_client().set_credentials(1, "test".into());
2469
2470        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2471        language_model::init(cx);
2472        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
2473        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
2474
2475        (
2476            ep_store,
2477            RequestChannels {
2478                predict: predict_req_rx,
2479                reject: reject_req_rx,
2480            },
2481        )
2482    })
2483}
2484
2485#[gpui::test]
2486async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
2487    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
2488    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
2489        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
2490    });
2491
2492    let edit_preview = cx
2493        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
2494        .await;
2495
2496    let prediction = EditPrediction {
2497        edits,
2498        cursor_position: None,
2499        edit_preview,
2500        buffer: buffer.clone(),
2501        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
2502        id: EditPredictionId("the-id".into()),
2503        inputs: ZetaPromptInput {
2504            events: Default::default(),
2505            related_files: Default::default(),
2506            active_buffer_diagnostics: vec![],
2507            cursor_path: Path::new("").into(),
2508            cursor_excerpt: "".into(),
2509            cursor_offset_in_excerpt: 0,
2510            excerpt_start_row: None,
2511            excerpt_ranges: Default::default(),
2512            syntax_ranges: None,
2513            in_open_source_repo: false,
2514            can_collect_data: false,
2515            repo_url: None,
2516        },
2517        model_version: None,
2518    };
2519
2520    cx.update(|cx| {
2521        assert_eq!(
2522            from_completion_edits(
2523                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2524                &buffer,
2525                cx
2526            ),
2527            vec![(2..5, "REM".into()), (9..11, "".into())]
2528        );
2529
2530        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
2531        assert_eq!(
2532            from_completion_edits(
2533                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2534                &buffer,
2535                cx
2536            ),
2537            vec![(2..2, "REM".into()), (6..8, "".into())]
2538        );
2539
2540        buffer.update(cx, |buffer, cx| buffer.undo(cx));
2541        assert_eq!(
2542            from_completion_edits(
2543                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2544                &buffer,
2545                cx
2546            ),
2547            vec![(2..5, "REM".into()), (9..11, "".into())]
2548        );
2549
2550        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
2551        assert_eq!(
2552            from_completion_edits(
2553                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2554                &buffer,
2555                cx
2556            ),
2557            vec![(3..3, "EM".into()), (7..9, "".into())]
2558        );
2559
2560        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
2561        assert_eq!(
2562            from_completion_edits(
2563                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2564                &buffer,
2565                cx
2566            ),
2567            vec![(4..4, "M".into()), (8..10, "".into())]
2568        );
2569
2570        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
2571        assert_eq!(
2572            from_completion_edits(
2573                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2574                &buffer,
2575                cx
2576            ),
2577            vec![(9..11, "".into())]
2578        );
2579
2580        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
2581        assert_eq!(
2582            from_completion_edits(
2583                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2584                &buffer,
2585                cx
2586            ),
2587            vec![(4..4, "M".into()), (8..10, "".into())]
2588        );
2589
2590        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
2591        assert_eq!(
2592            from_completion_edits(
2593                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
2594                &buffer,
2595                cx
2596            ),
2597            vec![(4..4, "M".into())]
2598        );
2599
2600        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
2601        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
2602    })
2603}
2604
2605#[gpui::test]
2606async fn test_clean_up_diff(cx: &mut TestAppContext) {
2607    init_test(cx);
2608
2609    assert_eq!(
2610        apply_edit_prediction(
2611            indoc! {"
2612                    fn main() {
2613                        let word_1 = \"lorem\";
2614                        let range = word.len()..word.len();
2615                    }
2616                "},
2617            indoc! {"
2618                    fn main() {
2619                        let word_1 = \"lorem\";
2620                        let range = word_1.len()..word_1.len();
2621                    }
2622                "},
2623            cx,
2624        )
2625        .await,
2626        indoc! {"
2627                fn main() {
2628                    let word_1 = \"lorem\";
2629                    let range = word_1.len()..word_1.len();
2630                }
2631            "},
2632    );
2633
2634    assert_eq!(
2635        apply_edit_prediction(
2636            indoc! {"
2637                    fn main() {
2638                        let story = \"the quick\"
2639                    }
2640                "},
2641            indoc! {"
2642                    fn main() {
2643                        let story = \"the quick brown fox jumps over the lazy dog\";
2644                    }
2645                "},
2646            cx,
2647        )
2648        .await,
2649        indoc! {"
2650                fn main() {
2651                    let story = \"the quick brown fox jumps over the lazy dog\";
2652                }
2653            "},
2654    );
2655}
2656
2657#[gpui::test]
2658async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
2659    init_test(cx);
2660
2661    let buffer_content = "lorem\n";
2662    let completion_response = "lorem\nipsum\n";
2663
2664    assert_eq!(
2665        apply_edit_prediction(buffer_content, completion_response, cx).await,
2666        "lorem\nipsum\n"
2667    );
2668}
2669
2670#[gpui::test]
2671async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
2672    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
2673    // When the buffer ends without a trailing newline, but the model returns output
2674    // with a trailing newline, zeta2 should normalize both sides before diffing
2675    // so no spurious newline is inserted.
2676    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2677    let fs = FakeFs::new(cx.executor());
2678
2679    // Single line buffer with no trailing newline
2680    fs.insert_tree(
2681        "/root",
2682        json!({
2683            "foo.txt": "hello"
2684        }),
2685    )
2686    .await;
2687    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2688
2689    let buffer = project
2690        .update(cx, |project, cx| {
2691            let path = project
2692                .find_project_path(path!("root/foo.txt"), cx)
2693                .unwrap();
2694            project.open_buffer(path, cx)
2695        })
2696        .await
2697        .unwrap();
2698
2699    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2700    let position = snapshot.anchor_before(language::Point::new(0, 5));
2701
2702    ep_store.update(cx, |ep_store, cx| {
2703        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2704    });
2705
2706    let (request, respond_tx) = requests.predict.next().await.unwrap();
2707
2708    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
2709    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
2710    let excerpt_length = request.input.cursor_excerpt.len();
2711    let response = PredictEditsV3Response {
2712        request_id: Uuid::new_v4().to_string(),
2713        output: "hello world\n".to_string(),
2714        editable_range: 0..excerpt_length,
2715        model_version: None,
2716    };
2717    respond_tx.send(response).unwrap();
2718
2719    cx.run_until_parked();
2720
2721    // The prediction should insert " world" without adding a newline
2722    ep_store.update(cx, |ep_store, cx| {
2723        let prediction = ep_store
2724            .prediction_at(&buffer, None, &project, cx)
2725            .expect("should have prediction");
2726        let edits: Vec<_> = prediction
2727            .edits
2728            .iter()
2729            .map(|(range, text)| {
2730                let snapshot = buffer.read(cx).snapshot();
2731                (range.to_offset(&snapshot), text.clone())
2732            })
2733            .collect();
2734        assert_eq!(edits, vec![(5..5, " world".into())]);
2735    });
2736}
2737
2738#[gpui::test]
2739async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) {
2740    let (ep_store, mut requests) = init_test_with_fake_client(cx);
2741    let fs = FakeFs::new(cx.executor());
2742
2743    fs.insert_tree(
2744        "/root",
2745        json!({
2746            "foo.txt": "hello"
2747        }),
2748    )
2749    .await;
2750    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2751
2752    let buffer = project
2753        .update(cx, |project, cx| {
2754            let path = project
2755                .find_project_path(path!("root/foo.txt"), cx)
2756                .unwrap();
2757            project.open_buffer(path, cx)
2758        })
2759        .await
2760        .unwrap();
2761
2762    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2763    let position = snapshot.anchor_before(language::Point::new(0, 5));
2764
2765    ep_store.update(cx, |ep_store, cx| {
2766        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
2767    });
2768
2769    let (request, respond_tx) = requests.predict.next().await.unwrap();
2770    let excerpt_length = request.input.cursor_excerpt.len();
2771    respond_tx
2772        .send(PredictEditsV3Response {
2773            request_id: Uuid::new_v4().to_string(),
2774            output: "hello<|user_cursor|> world".to_string(),
2775            editable_range: 0..excerpt_length,
2776            model_version: None,
2777        })
2778        .unwrap();
2779
2780    cx.run_until_parked();
2781
2782    ep_store.update(cx, |ep_store, cx| {
2783        let prediction = ep_store
2784            .prediction_at(&buffer, None, &project, cx)
2785            .expect("should have prediction");
2786        let snapshot = buffer.read(cx).snapshot();
2787        let edits: Vec<_> = prediction
2788            .edits
2789            .iter()
2790            .map(|(range, text)| (range.to_offset(&snapshot), text.clone()))
2791            .collect();
2792
2793        assert_eq!(edits, vec![(5..5, " world".into())]);
2794    });
2795}
2796
2797fn init_test(cx: &mut TestAppContext) {
2798    cx.update(|cx| {
2799        cx.set_global(AppDatabase::test_new());
2800        let settings_store = SettingsStore::test(cx);
2801        cx.set_global(settings_store);
2802    });
2803}
2804
2805async fn apply_edit_prediction(
2806    buffer_content: &str,
2807    completion_response: &str,
2808    cx: &mut TestAppContext,
2809) -> String {
2810    let fs = project::FakeFs::new(cx.executor());
2811    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2812    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2813    let (ep_store, response) = make_test_ep_store(&project, cx).await;
2814    *response.lock() = completion_response.to_string();
2815    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2816    buffer.update(cx, |buffer, cx| {
2817        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2818    });
2819    buffer.read_with(cx, |buffer, _| buffer.text())
2820}
2821
2822async fn run_edit_prediction(
2823    buffer: &Entity<Buffer>,
2824    project: &Entity<Project>,
2825    ep_store: &Entity<EditPredictionStore>,
2826    cx: &mut TestAppContext,
2827) -> EditPrediction {
2828    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2829    ep_store.update(cx, |ep_store, cx| {
2830        ep_store.register_buffer(buffer, &project, cx)
2831    });
2832    cx.background_executor.run_until_parked();
2833    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2834        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2835    });
2836    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2837}
2838
2839async fn make_test_ep_store(
2840    project: &Entity<Project>,
2841    cx: &mut TestAppContext,
2842) -> (Entity<EditPredictionStore>, Arc<Mutex<String>>) {
2843    let default_response = "hello world\n".to_string();
2844    let completion_response: Arc<Mutex<String>> = Arc::new(Mutex::new(default_response));
2845    let http_client = FakeHttpClient::create({
2846        let completion_response = completion_response.clone();
2847        let mut next_request_id = 0;
2848        move |req| {
2849            let completion_response = completion_response.clone();
2850            let method = req.method().clone();
2851            let uri = req.uri().path().to_string();
2852            let mut body = req.into_body();
2853            async move {
2854                match (method, uri.as_str()) {
2855                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2856                        .status(200)
2857                        .body(
2858                            serde_json::to_string(&CreateLlmTokenResponse {
2859                                token: LlmToken("the-llm-token".to_string()),
2860                            })
2861                            .unwrap()
2862                            .into(),
2863                        )
2864                        .unwrap()),
2865                    (Method::POST, "/predict_edits/v3") => {
2866                        let mut buf = Vec::new();
2867                        body.read_to_end(&mut buf).await.ok();
2868                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
2869                        let req: PredictEditsV3Request =
2870                            serde_json::from_slice(&decompressed).unwrap();
2871
2872                        next_request_id += 1;
2873                        Ok(http_client::Response::builder()
2874                            .status(200)
2875                            .body(
2876                                serde_json::to_string(&PredictEditsV3Response {
2877                                    request_id: format!("request-{next_request_id}"),
2878                                    editable_range: 0..req.input.cursor_excerpt.len(),
2879                                    output: completion_response.lock().clone(),
2880                                    model_version: None,
2881                                })
2882                                .unwrap()
2883                                .into(),
2884                            )
2885                            .unwrap())
2886                    }
2887                    _ => Ok(http_client::Response::builder()
2888                        .status(404)
2889                        .body("Not Found".to_string().into())
2890                        .unwrap()),
2891                }
2892            }
2893        }
2894    });
2895
2896    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2897    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
2898    cx.update(|cx| {
2899        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
2900    });
2901    let _server = FakeServer::for_client(42, &client, cx).await;
2902
2903    let ep_store = cx.new(|cx| {
2904        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2905        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
2906
2907        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2908        for worktree in worktrees {
2909            let worktree_id = worktree.read(cx).id();
2910            ep_store
2911                .get_or_init_project(project, cx)
2912                .license_detection_watchers
2913                .entry(worktree_id)
2914                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2915        }
2916
2917        ep_store
2918    });
2919
2920    (ep_store, completion_response)
2921}
2922
2923fn to_completion_edits(
2924    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2925    buffer: &Entity<Buffer>,
2926    cx: &App,
2927) -> Vec<(Range<Anchor>, Arc<str>)> {
2928    let buffer = buffer.read(cx);
2929    iterator
2930        .into_iter()
2931        .map(|(range, text)| {
2932            (
2933                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2934                text,
2935            )
2936        })
2937        .collect()
2938}
2939
2940fn from_completion_edits(
2941    editor_edits: &[(Range<Anchor>, Arc<str>)],
2942    buffer: &Entity<Buffer>,
2943    cx: &App,
2944) -> Vec<(Range<usize>, Arc<str>)> {
2945    let buffer = buffer.read(cx);
2946    editor_edits
2947        .iter()
2948        .map(|(range, text)| {
2949            (
2950                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2951                text.clone(),
2952            )
2953        })
2954        .collect()
2955}
2956
2957#[gpui::test]
2958async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2959    init_test(cx);
2960
2961    let fs = FakeFs::new(cx.executor());
2962    fs.insert_tree(
2963        "/project",
2964        serde_json::json!({
2965            "main.rs": "fn main() {\n    \n}\n"
2966        }),
2967    )
2968    .await;
2969
2970    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2971
2972    let http_client = FakeHttpClient::create(|_req| async move {
2973        Ok(gpui::http_client::Response::builder()
2974            .status(401)
2975            .body("Unauthorized".into())
2976            .unwrap())
2977    });
2978
2979    let client =
2980        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2981    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
2982    cx.update(|cx| {
2983        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
2984    });
2985
2986    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2987
2988    let buffer = project
2989        .update(cx, |project, cx| {
2990            let path = project
2991                .find_project_path(path!("/project/main.rs"), cx)
2992                .unwrap();
2993            project.open_buffer(path, cx)
2994        })
2995        .await
2996        .unwrap();
2997
2998    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2999    ep_store.update(cx, |ep_store, cx| {
3000        ep_store.register_buffer(&buffer, &project, cx)
3001    });
3002    cx.background_executor.run_until_parked();
3003
3004    let completion_task = ep_store.update(cx, |ep_store, cx| {
3005        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta);
3006        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
3007    });
3008
3009    let result = completion_task.await;
3010    assert!(
3011        result.is_err(),
3012        "Without authentication and without custom URL, prediction should fail"
3013    );
3014}
3015
// Checks that `next_diagnostic_location` avoids diagnostics near a
// collaborator's cursor unless the local cursor is also near them. This is
// exercised both inside the active buffer and when picking a diagnostic from
// another file.
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
    // Simulates a remote collaborator (replica 10) whose cursor is a single
    // empty selection at column 0 of `row` in `buffer`.
    fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
        let collab_replica = clock::ReplicaId::new(10);
        let anchor = buffer.read_with(cx, |buffer, _| {
            buffer.snapshot().anchor_before(Point::new(row, 0))
        });
        let selections: Arc<[Selection<Anchor>]> = Arc::new([Selection {
            id: 1,
            start: anchor,
            end: anchor,
            reversed: false,
            goal: SelectionGoal::None,
        }]);
        // Apply the selection as if it came in as a remote operation.
        buffer.update(cx, |buffer, cx| {
            buffer.apply_ops(
                [Operation::UpdateSelections {
                    selections,
                    lamport_timestamp: clock::Lamport {
                        replica_id: collab_replica,
                        value: 1,
                    },
                    line_mode: false,
                    cursor_shape: CursorShape::Bar,
                }],
                cx,
            );
        });
    }

    // Pushes one ERROR diagnostic, covering columns 0..5, for each row in
    // `rows` of the file at `uri_path`. The diagnostics are attributed to
    // language server 0 and sent through the project's LSP store.
    fn publish_diagnostics(
        uri_path: &'static str,
        rows: &[u32],
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) {
        let diagnostics: Vec<_> = rows
            .iter()
            .map(|&row| lsp::Diagnostic {
                range: lsp::Range::new(lsp::Position::new(row, 0), lsp::Position::new(row, 5)),
                severity: Some(lsp::DiagnosticSeverity::ERROR),
                message: format!("error at row {row}"),
                ..Default::default()
            })
            .collect();
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: lsp::Uri::from_file_path(uri_path).expect("invalid uri"),
                            diagnostics,
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .expect("failed to update diagnostics");
            });
        });
    }

    init_test(cx);

    // 60 short lines, so the diagnostics at rows 3, 25 and 50 are far apart.
    let mut lines = String::new();
    for i in 0..60 {
        lines.push_str(&format!("line {i}\n"));
    }

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "active.txt": lines,
            "collab_file.txt": "error here\nsecond line\n",
            "free_file.txt": "another error\nsecond line\n",
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let active_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/active.txt"), cx)
                .expect("active.txt not found");
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open active buffer");

    // Collaborator sits at row 5, right next to the row-3 diagnostic.
    set_collaborator_cursor(&active_buffer, 5, cx);

    publish_diagnostics(path!("/root/active.txt"), &[3, 25, 50], &project, cx);

    cx.run_until_parked();

    // Case 1: local cursor at row 25, far from the collaborator. Row 3 must
    // not be chosen.
    let cursor_point = Point::new(25, 0);
    let empty_search_range: Range<Point> = Default::default();

    let snapshot = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot,
        empty_search_range.clone(),
        cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (result_buffer, result_anchor) = result.expect("expected a diagnostic location");
    assert_eq!(result_buffer.entity_id(), active_buffer.entity_id());
    let result_row = result_buffer.read_with(cx, |buffer, _| {
        result_anchor.to_point(&buffer.snapshot()).row
    });
    assert_ne!(
        result_row, 3,
        "row 3 is near collaborator (row 5) but far from local cursor (row 25), should be excluded"
    );
    assert!(
        result_row == 25 || result_row == 50,
        "expected row 25 or 50, got {result_row}"
    );

    // Case 2: local cursor at row 4, near both the collaborator and row 3.
    // Row 3 is allowed again.
    let snapshot_near = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let near_cursor_point = Point::new(4, 0);
    let result_near = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_near,
        empty_search_range.clone(),
        near_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, near_anchor) = result_near.expect("expected a diagnostic location when both are near");
    let near_row =
        active_buffer.read_with(cx, |buffer, _| near_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        near_row, 3,
        "row 3 should be included when local cursor (row 4) is also near the collaborator"
    );

    // Case 3: local cursor at row 50. The diagnostic on that row is expected.
    let snapshot_far = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let far_cursor_point = Point::new(50, 0);
    let result_far = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_far,
        empty_search_range.clone(),
        far_cursor_point,
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("next_diagnostic_location failed");

    let (_, far_anchor) = result_far.expect("expected a diagnostic location");
    let far_row =
        active_buffer.read_with(cx, |buffer, _| far_anchor.to_point(&buffer.snapshot()).row);
    assert_eq!(
        far_row, 50,
        "row 50 is near local cursor (row 50) and far from collaborator, should be picked"
    );

    // Case 4 (cross-file): both other files get a row-0 diagnostic, and a
    // collaborator cursor is placed in collab_file.txt only.
    publish_diagnostics(path!("/root/collab_file.txt"), &[0], &project, cx);
    publish_diagnostics(path!("/root/free_file.txt"), &[0], &project, cx);
    cx.run_until_parked();

    let collab_buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/root/collab_file.txt"), cx)
                .expect("collab_file.txt not found");
            project.open_buffer(path, cx)
        })
        .await
        .expect("failed to open collab buffer");

    set_collaborator_cursor(&collab_buffer, 0, cx);
    cx.run_until_parked();

    // The search range spans rows 0..59 of the active buffer. Judging by the
    // variable name, this is meant to rule out same-file candidates so the
    // result has to come from another file (NOTE(review): confirm against
    // `next_diagnostic_location`).
    let no_same_file_search_range = Point::new(0, 0)..Point::new(59, 0);
    let snapshot_cross = active_buffer.read_with(cx, |buffer, _| buffer.snapshot());
    let result_cross = EditPredictionStore::next_diagnostic_location(
        active_buffer.clone(),
        &snapshot_cross,
        no_same_file_search_range,
        Point::new(0, 0),
        &project,
        &mut cx.to_async(),
    )
    .await
    .expect("cross-file next_diagnostic_location failed");

    let (cross_buffer, _) = result_cross.expect("expected a cross-file diagnostic location");
    let cross_path = cross_buffer.read_with(cx, |buffer, cx| {
        buffer
            .file()
            .expect("buffer should have a file")
            .full_path(cx)
    });
    assert_eq!(
        cross_path,
        Path::new(path!("root/free_file.txt")),
        "should skip collab_file.txt (has collaborator) and pick free_file.txt"
    );
}
3231
3232#[gpui::test]
3233async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
3234    let (ep_store, _requests) = init_test_with_fake_client(cx);
3235    let fs = FakeFs::new(cx.executor());
3236
3237    // Buffer with two clearly separated regions:
3238    //   Region A = lines 0-9   (offsets 0..50)
3239    //   Region B = lines 20-29 (offsets 105..155)
3240    // A big gap in between so edits in one region never overlap the other.
3241    let mut content = String::new();
3242    for i in 0..30 {
3243        content.push_str(&format!("line {i:02}\n"));
3244    }
3245
3246    fs.insert_tree(
3247        "/root",
3248        json!({
3249            "foo.md": content.clone()
3250        }),
3251    )
3252    .await;
3253    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3254
3255    let buffer = project
3256        .update(cx, |project, cx| {
3257            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3258            project.open_buffer(path, cx)
3259        })
3260        .await
3261        .unwrap();
3262
3263    type SettledEventRecord = (EditPredictionId, String);
3264    let settled_events: Arc<Mutex<Vec<SettledEventRecord>>> = Arc::new(Mutex::new(Vec::new()));
3265
3266    ep_store.update(cx, |ep_store, cx| {
3267        ep_store.register_buffer(&buffer, &project, cx);
3268
3269        let settled_events = settled_events.clone();
3270        ep_store.settled_event_callback = Some(Box::new(move |id, text| {
3271            settled_events.lock().push((id, text));
3272        }));
3273    });
3274
3275    // --- Phase 1: edit in region A and enqueue prediction A ---
3276
3277    buffer.update(cx, |buffer, cx| {
3278        // Edit at the start of line 0.
3279        buffer.edit(vec![(0..0, "ADDED ")], None, cx);
3280    });
3281    cx.run_until_parked();
3282
3283    let snapshot_a = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3284    let empty_edits: Arc<[(Range<Anchor>, Arc<str>)]> = Vec::new().into();
3285    let edit_preview_a = buffer
3286        .read_with(cx, |buffer, cx| {
3287            buffer.preview_edits(empty_edits.clone(), cx)
3288        })
3289        .await;
3290
3291    // Region A: first 10 lines of the buffer.
3292    let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0));
3293
3294    ep_store.update(cx, |ep_store, cx| {
3295        ep_store.enqueue_settled_prediction(
3296            EditPredictionId("prediction-a".into()),
3297            &project,
3298            &buffer,
3299            &snapshot_a,
3300            editable_region_a.clone(),
3301            &edit_preview_a,
3302            None,
3303            Duration::from_secs(0),
3304            cx,
3305        );
3306    });
3307
3308    // --- Phase 2: repeatedly edit in region A to keep it unsettled ---
3309
3310    // Let the worker process the channel message before we start advancing.
3311    cx.run_until_parked();
3312
3313    let mut region_a_edit_offset = 5;
3314    for _ in 0..3 {
3315        // Edit inside region A (not at the boundary) so `last_edit_at` is
3316        // updated before the worker's next wake.
3317        buffer.update(cx, |buffer, cx| {
3318            buffer.edit(
3319                vec![(region_a_edit_offset..region_a_edit_offset, "x")],
3320                None,
3321                cx,
3322            );
3323        });
3324        region_a_edit_offset += 1;
3325        cx.run_until_parked();
3326
3327        cx.executor()
3328            .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 2);
3329        cx.run_until_parked();
3330        assert!(
3331            settled_events.lock().is_empty(),
3332            "no settled events should fire while region A is still being edited"
3333        );
3334    }
3335
3336    // Still nothing settled.
3337    assert!(settled_events.lock().is_empty());
3338
3339    // --- Phase 3: edit in distinct region B, enqueue prediction B ---
3340    // Advance a small amount so B's quiescence window starts later than A's,
3341    // but not so much that A settles (A's last edit was at the start of
3342    // iteration 3, and it needs a full Q to settle).
3343    cx.executor()
3344        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
3345    cx.run_until_parked();
3346    assert!(settled_events.lock().is_empty());
3347
3348    let snapshot_b = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3349    let line_20_offset = snapshot_b.point_to_offset(Point::new(20, 0));
3350
3351    buffer.update(cx, |buffer, cx| {
3352        buffer.edit(vec![(line_20_offset..line_20_offset, "NEW ")], None, cx);
3353    });
3354    cx.run_until_parked();
3355
3356    let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3357    let edit_preview_b = buffer
3358        .read_with(cx, |buffer, cx| buffer.preview_edits(empty_edits, cx))
3359        .await;
3360    let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0));
3361
3362    ep_store.update(cx, |ep_store, cx| {
3363        ep_store.enqueue_settled_prediction(
3364            EditPredictionId("prediction-b".into()),
3365            &project,
3366            &buffer,
3367            &snapshot_b2,
3368            editable_region_b.clone(),
3369            &edit_preview_b,
3370            None,
3371            Duration::from_secs(0),
3372            cx,
3373        );
3374    });
3375
3376    cx.run_until_parked();
3377    assert!(
3378        settled_events.lock().is_empty(),
3379        "neither prediction should have settled yet"
3380    );
3381
3382    // --- Phase 4: let enough time pass for region A to settle ---
3383    // A's last edit was at T_a (during the last loop iteration). The worker is
3384    // sleeping until T_a + Q. We advance just enough to reach that wake time
3385    // (Q/4 since we already advanced Q/4 in phase 3 on top of the loop's
3386    // 3*Q/2). At that point A has been quiet for Q and settles, but B was
3387    // enqueued only Q/4 ago and stays pending.
3388    cx.executor()
3389        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE / 4);
3390    cx.run_until_parked();
3391
3392    {
3393        let events = settled_events.lock().clone();
3394        assert_eq!(
3395            events.len(),
3396            1,
3397            "prediction and capture_sample for A should have settled, got: {events:?}"
3398        );
3399        assert_eq!(events[0].0, EditPredictionId("prediction-a".into()));
3400    }
3401
3402    // --- Phase 5: let more time pass for region B to settle ---
3403    // B's last edit was Q/4 before A settled. The worker rescheduled to
3404    // B's last_edit_at + Q, which is 3Q/4 from now.
3405    cx.executor()
3406        .advance_clock(EDIT_PREDICTION_SETTLED_QUIESCENCE * 3 / 4);
3407    cx.run_until_parked();
3408
3409    {
3410        let events = settled_events.lock().clone();
3411        assert_eq!(
3412            events.len(),
3413            2,
3414            "both prediction and capture_sample settled events should be emitted for each request, got: {events:?}"
3415        );
3416        assert_eq!(events[1].0, EditPredictionId("prediction-b".into()));
3417    }
3418}
3419
3420#[gpui::test]
3421async fn test_data_collection_disabled_by_default(cx: &mut TestAppContext) {
3422    let (ep_store, _channels) = init_test_with_fake_client(cx);
3423
3424    cx.update(|cx| {
3425        assert!(!ep_store.read(cx).is_data_collection_enabled(cx));
3426    });
3427}
3428
3429#[gpui::test]
3430async fn test_data_collection_enabled_via_legacy_kv_store(cx: &mut TestAppContext) {
3431    let (ep_store, _channels) =
3432        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3433
3434    cx.update(|cx| {
3435        assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3436    });
3437}
3438
3439#[gpui::test]
3440async fn test_data_collection_default_uses_cached_legacy_value(cx: &mut TestAppContext) {
3441    let (ep_store, _channels) =
3442        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3443
3444    cx.update(|cx| {
3445        assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3446    });
3447
3448    cx.update(|cx| KeyValueStore::global(cx))
3449        .delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3450        .await
3451        .unwrap();
3452
3453    cx.update(|cx| {
3454        assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3455    });
3456}
3457
3458#[gpui::test]
3459async fn test_data_collection_setting_overrides_kv_store(cx: &mut TestAppContext) {
3460    let (ep_store, _channels) =
3461        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3462
3463    // An explicit false in settings.json wins over the KV store.
3464    cx.update_global::<SettingsStore, _>(|settings, cx| {
3465        settings.update_user_settings(cx, |content| {
3466            content
3467                .project
3468                .all_languages
3469                .edit_predictions
3470                .get_or_insert_default()
3471                .allow_data_collection = Some(EditPredictionDataCollectionChoice::No);
3472        });
3473    });
3474
3475    cx.update(|cx| {
3476        assert!(!ep_store.read(cx).is_data_collection_enabled(cx));
3477    });
3478}
3479
3480#[gpui::test]
3481async fn test_data_collection_enabled_via_setting(cx: &mut TestAppContext) {
3482    let (ep_store, _channels) = init_test_with_fake_client(cx);
3483
3484    cx.update_global::<SettingsStore, _>(|settings, cx| {
3485        settings.update_user_settings(cx, |content| {
3486            content
3487                .project
3488                .all_languages
3489                .edit_predictions
3490                .get_or_insert_default()
3491                .allow_data_collection = Some(EditPredictionDataCollectionChoice::Yes);
3492        });
3493    });
3494
3495    cx.update(|cx| {
3496        assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3497    });
3498}
3499
3500#[gpui::test]
3501async fn test_data_collection_always_enabled_for_staff(cx: &mut TestAppContext) {
3502    let (ep_store, _channels) = init_test_with_fake_client(cx);
3503
3504    cx.update(|cx| {
3505        cx.set_staff(true);
3506        assert!(ep_store.read(cx).is_data_collection_enabled(cx));
3507    });
3508}
3509
3510#[gpui::test]
3511async fn test_data_collection_disabled_by_organization_configuration(cx: &mut TestAppContext) {
3512    let (ep_store, _channels) = init_test_with_fake_client(cx);
3513
3514    cx.update_global::<SettingsStore, _>(|settings, cx| {
3515        settings.update_user_settings(cx, |content| {
3516            content
3517                .project
3518                .all_languages
3519                .edit_predictions
3520                .get_or_insert_default()
3521                .allow_data_collection = Some(EditPredictionDataCollectionChoice::Yes);
3522        });
3523    });
3524
3525    let user_store = cx.update(|cx| ep_store.read(cx).user_store.clone());
3526    cx.update(|cx| {
3527        user_store.update(cx, |user_store, cx| {
3528            user_store.set_current_organization_configuration_for_test(
3529                Arc::new(Organization {
3530                    id: OrganizationId("org-1".into()),
3531                    name: "Org 1".into(),
3532                    is_personal: false,
3533                }),
3534                OrganizationConfiguration {
3535                    is_zed_model_provider_enabled: true,
3536                    is_agent_thread_feedback_enabled: true,
3537                    is_collaboration_enabled: true,
3538                    edit_prediction: OrganizationEditPredictionConfiguration {
3539                        is_enabled: true,
3540                        is_feedback_enabled: false,
3541                    },
3542                },
3543                cx,
3544            );
3545        });
3546
3547        assert!(!ep_store.read(cx).is_data_collection_enabled(cx));
3548    });
3549}
3550
3551// When a user had data collection enabled via the legacy KV store (with no explicit
3552// setting in settings.json), toggle_data_collection must read the *resolved* state
3553// (true) and write Some(false).
3554#[gpui::test]
3555async fn test_toggle_data_collection_from_kv_enabled_state(cx: &mut TestAppContext) {
3556    let (ep_store, _channels) =
3557        init_test_with_fake_client_and_legacy_data_collection(cx, Some("true"));
3558
3559    cx.update(|cx| {
3560        assert!(
3561            ep_store.read(cx).is_data_collection_enabled(cx),
3562            "data collection should be enabled via KV store before toggle"
3563        );
3564    });
3565
3566    // Simulate what toggle_data_collection does: capture the resolved current
3567    // state, then write its inverse.
3568    let is_currently_enabled = cx.update(|cx| ep_store.read(cx).is_data_collection_enabled(cx));
3569    cx.update_global::<SettingsStore, _>(|settings, cx| {
3570        settings.update_user_settings(cx, |content| {
3571            content
3572                .project
3573                .all_languages
3574                .edit_predictions
3575                .get_or_insert_default()
3576                .allow_data_collection = Some(if is_currently_enabled {
3577                EditPredictionDataCollectionChoice::No
3578            } else {
3579                EditPredictionDataCollectionChoice::Yes
3580            });
3581        });
3582    });
3583
3584    cx.update(|cx| {
3585        assert!(
3586            !ep_store.read(cx).is_data_collection_enabled(cx),
3587            "data collection should be disabled after toggling off from KV-enabled state"
3588        );
3589    });
3590}
3591
3592#[gpui::test]
3593async fn test_upsell_shown_by_default(cx: &mut TestAppContext) {
3594    init_test(cx);
3595    let kvp = cx.update(|cx| KeyValueStore::global(cx));
3596    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3597        .await
3598        .ok();
3599    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.ok();
3600
3601    cx.update(|cx| assert!(should_show_upsell_modal(cx)));
3602}
3603
3604#[gpui::test]
3605async fn test_upsell_dismissed_when_data_collection_choice_in_kv_store(cx: &mut TestAppContext) {
3606    init_test(cx);
3607
3608    // Any value for the data collection key means the old upsell was already
3609    // shown, regardless of whether data collection was accepted or declined.
3610    for value in &["true", "false"] {
3611        cx.update(|cx| KeyValueStore::global(cx))
3612            .write_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), value.to_string())
3613            .await
3614            .unwrap();
3615
3616        cx.update(|cx| {
3617            assert!(
3618                !should_show_upsell_modal(cx),
3619                "upsell should be suppressed when data collection choice is '{value}'"
3620            );
3621        });
3622    }
3623
3624    cx.update(|cx| KeyValueStore::global(cx))
3625        .delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3626        .await
3627        .unwrap();
3628}
3629
3630#[gpui::test]
3631async fn test_upsell_dismissed_when_dismissed_key_set(cx: &mut TestAppContext) {
3632    init_test(cx);
3633    let kvp = cx.update(|cx| KeyValueStore::global(cx));
3634    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3635        .await
3636        .ok();
3637    kvp.write_kvp(ZedPredictUpsell::KEY.into(), "1".into())
3638        .await
3639        .unwrap();
3640
3641    cx.update(|cx| assert!(!should_show_upsell_modal(cx)));
3642
3643    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.unwrap();
3644}
3645
3646#[gpui::test]
3647async fn test_upsell_dismissed_via_dismissable_api(cx: &mut TestAppContext) {
3648    init_test(cx);
3649    let kvp = cx.update(|cx| KeyValueStore::global(cx));
3650    kvp.delete_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE.into())
3651        .await
3652        .ok();
3653    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.ok();
3654
3655    cx.update(|cx| {
3656        assert!(should_show_upsell_modal(cx));
3657        ZedPredictUpsell::set_dismissed(true, cx);
3658    });
3659    cx.run_until_parked();
3660
3661    cx.update(|cx| assert!(!should_show_upsell_modal(cx)));
3662
3663    kvp.delete_kvp(ZedPredictUpsell::KEY.into()).await.unwrap();
3664}
3665
/// Calls `zlog::init_test()` (presumably test-logger setup) for this test binary.
/// `#[ctor::ctor]` makes it run automatically when the binary is loaded, before
/// any test executes, so individual tests do not need to call it.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}