// edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  10};
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::{
  16    Entity, TestAppContext,
  17    http_client::{FakeHttpClient, Response},
  18};
  19use indoc::indoc;
  20use language::Point;
  21use lsp::LanguageServerId;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::{path, rel_path::rel_path};
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
  34#[gpui::test]
  35async fn test_current_state(cx: &mut TestAppContext) {
  36    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  37    let fs = FakeFs::new(cx.executor());
  38    fs.insert_tree(
  39        "/root",
  40        json!({
  41            "1.txt": "Hello!\nHow\nBye\n",
  42            "2.txt": "Hola!\nComo\nAdios\n"
  43        }),
  44    )
  45    .await;
  46    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  47
  48    let buffer1 = project
  49        .update(cx, |project, cx| {
  50            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  51            project.set_active_path(Some(path.clone()), cx);
  52            project.open_buffer(path, cx)
  53        })
  54        .await
  55        .unwrap();
  56    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  57    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  58
  59    ep_store.update(cx, |ep_store, cx| {
  60        ep_store.register_project(&project, cx);
  61        ep_store.register_buffer(&buffer1, &project, cx);
  62    });
  63
  64    // Prediction for current file
  65
  66    ep_store.update(cx, |ep_store, cx| {
  67        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  68    });
  69    let (request, respond_tx) = requests.predict.next().await.unwrap();
  70
  71    respond_tx
  72        .send(model_response(
  73            &request,
  74            indoc! {r"
  75                --- a/root/1.txt
  76                +++ b/root/1.txt
  77                @@ ... @@
  78                 Hello!
  79                -How
  80                +How are you?
  81                 Bye
  82            "},
  83        ))
  84        .unwrap();
  85
  86    cx.run_until_parked();
  87
  88    ep_store.update(cx, |ep_store, cx| {
  89        let prediction = ep_store
  90            .prediction_at(&buffer1, None, &project, cx)
  91            .unwrap();
  92        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  93    });
  94
  95    ep_store.update(cx, |ep_store, cx| {
  96        ep_store.reject_current_prediction(
  97            EditPredictionRejectReason::Discarded,
  98            &project,
  99            edit_prediction_types::EditPredictionDismissReason::Ignored,
 100            cx,
 101        );
 102    });
 103
 104    // Prediction for diagnostic in another file
 105
 106    let diagnostic = lsp::Diagnostic {
 107        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 108        severity: Some(lsp::DiagnosticSeverity::ERROR),
 109        message: "Sentence is incomplete".to_string(),
 110        ..Default::default()
 111    };
 112
 113    project.update(cx, |project, cx| {
 114        project.lsp_store().update(cx, |lsp_store, cx| {
 115            lsp_store
 116                .update_diagnostics(
 117                    LanguageServerId(0),
 118                    lsp::PublishDiagnosticsParams {
 119                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 120                        diagnostics: vec![diagnostic],
 121                        version: None,
 122                    },
 123                    None,
 124                    language::DiagnosticSourceKind::Pushed,
 125                    &[],
 126                    cx,
 127                )
 128                .unwrap();
 129        });
 130    });
 131
 132    let (request, respond_tx) = requests.predict.next().await.unwrap();
 133    respond_tx
 134        .send(model_response(
 135            &request,
 136            indoc! {r#"
 137                --- a/root/2.txt
 138                +++ b/root/2.txt
 139                @@ ... @@
 140                 Hola!
 141                -Como
 142                +Como estas?
 143                 Adios
 144            "#},
 145        ))
 146        .unwrap();
 147    cx.run_until_parked();
 148
 149    ep_store.update(cx, |ep_store, cx| {
 150        let prediction = ep_store
 151            .prediction_at(&buffer1, None, &project, cx)
 152            .unwrap();
 153        assert_matches!(
 154            prediction,
 155            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 156        );
 157    });
 158
 159    let buffer2 = project
 160        .update(cx, |project, cx| {
 161            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 162            project.open_buffer(path, cx)
 163        })
 164        .await
 165        .unwrap();
 166
 167    ep_store.update(cx, |ep_store, cx| {
 168        let prediction = ep_store
 169            .prediction_at(&buffer2, None, &project, cx)
 170            .unwrap();
 171        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 172    });
 173}
 174
 175#[gpui::test]
 176async fn test_simple_request(cx: &mut TestAppContext) {
 177    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 178    let fs = FakeFs::new(cx.executor());
 179    fs.insert_tree(
 180        "/root",
 181        json!({
 182            "foo.md":  "Hello!\nHow\nBye\n"
 183        }),
 184    )
 185    .await;
 186    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 187
 188    let buffer = project
 189        .update(cx, |project, cx| {
 190            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 191            project.open_buffer(path, cx)
 192        })
 193        .await
 194        .unwrap();
 195    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 196    let position = snapshot.anchor_before(language::Point::new(1, 3));
 197
 198    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 199        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 200    });
 201
 202    let (request, respond_tx) = requests.predict.next().await.unwrap();
 203
 204    // TODO Put back when we have a structured request again
 205    // assert_eq!(
 206    //     request.excerpt_path.as_ref(),
 207    //     Path::new(path!("root/foo.md"))
 208    // );
 209    // assert_eq!(
 210    //     request.cursor_point,
 211    //     Point {
 212    //         line: Line(1),
 213    //         column: 3
 214    //     }
 215    // );
 216
 217    respond_tx
 218        .send(model_response(
 219            &request,
 220            indoc! { r"
 221                --- a/root/foo.md
 222                +++ b/root/foo.md
 223                @@ ... @@
 224                 Hello!
 225                -How
 226                +How are you?
 227                 Bye
 228            "},
 229        ))
 230        .unwrap();
 231
 232    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 233
 234    assert_eq!(prediction.edits.len(), 1);
 235    assert_eq!(
 236        prediction.edits[0].0.to_point(&snapshot).start,
 237        language::Point::new(1, 3)
 238    );
 239    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 240}
 241
 242#[gpui::test]
 243async fn test_request_events(cx: &mut TestAppContext) {
 244    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 245    let fs = FakeFs::new(cx.executor());
 246    fs.insert_tree(
 247        "/root",
 248        json!({
 249            "foo.md": "Hello!\n\nBye\n"
 250        }),
 251    )
 252    .await;
 253    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 254
 255    let buffer = project
 256        .update(cx, |project, cx| {
 257            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 258            project.open_buffer(path, cx)
 259        })
 260        .await
 261        .unwrap();
 262
 263    ep_store.update(cx, |ep_store, cx| {
 264        ep_store.register_buffer(&buffer, &project, cx);
 265    });
 266
 267    buffer.update(cx, |buffer, cx| {
 268        buffer.edit(vec![(7..7, "How")], None, cx);
 269    });
 270
 271    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 272    let position = snapshot.anchor_before(language::Point::new(1, 3));
 273
 274    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 275        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 276    });
 277
 278    let (request, respond_tx) = requests.predict.next().await.unwrap();
 279
 280    let prompt = prompt_from_request(&request);
 281    assert!(
 282        prompt.contains(indoc! {"
 283        --- a/root/foo.md
 284        +++ b/root/foo.md
 285        @@ -1,3 +1,3 @@
 286         Hello!
 287        -
 288        +How
 289         Bye
 290    "}),
 291        "{prompt}"
 292    );
 293
 294    respond_tx
 295        .send(model_response(
 296            &request,
 297            indoc! {r#"
 298                --- a/root/foo.md
 299                +++ b/root/foo.md
 300                @@ ... @@
 301                 Hello!
 302                -How
 303                +How are you?
 304                 Bye
 305        "#},
 306        ))
 307        .unwrap();
 308
 309    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 310
 311    assert_eq!(prediction.edits.len(), 1);
 312    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 313}
 314
 315#[gpui::test]
 316async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 317    let (ep_store, _requests) = init_test_with_fake_client(cx);
 318    let fs = FakeFs::new(cx.executor());
 319    fs.insert_tree(
 320        "/root",
 321        json!({
 322            "foo.md": "Hello!\n\nBye\n"
 323        }),
 324    )
 325    .await;
 326    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 327
 328    let buffer = project
 329        .update(cx, |project, cx| {
 330            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 331            project.open_buffer(path, cx)
 332        })
 333        .await
 334        .unwrap();
 335
 336    ep_store.update(cx, |ep_store, cx| {
 337        ep_store.register_buffer(&buffer, &project, cx);
 338    });
 339
 340    // First burst: insert "How"
 341    buffer.update(cx, |buffer, cx| {
 342        buffer.edit(vec![(7..7, "How")], None, cx);
 343    });
 344
 345    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 346    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 347    cx.run_until_parked();
 348
 349    // Second burst: append " are you?" immediately after "How" on the same line.
 350    //
 351    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 352    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 353    buffer.update(cx, |buffer, cx| {
 354        buffer.edit(vec![(10..10, " are you?")], None, cx);
 355    });
 356
 357    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 358    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 359    buffer.update(cx, |buffer, cx| {
 360        buffer.edit(vec![(19..19, "!")], None, cx);
 361    });
 362
 363    // Without time-based splitting, there is one event.
 364    let events = ep_store.update(cx, |ep_store, cx| {
 365        ep_store.edit_history_for_project(&project, cx)
 366    });
 367    assert_eq!(events.len(), 1);
 368    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 369    assert_eq!(
 370        diff.as_str(),
 371        indoc! {"
 372            @@ -1,3 +1,3 @@
 373             Hello!
 374            -
 375            +How are you?!
 376             Bye
 377        "}
 378    );
 379
 380    // With time-based splitting, there are two distinct events.
 381    let events = ep_store.update(cx, |ep_store, cx| {
 382        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 383    });
 384    assert_eq!(events.len(), 2);
 385    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 386    assert_eq!(
 387        diff.as_str(),
 388        indoc! {"
 389            @@ -1,3 +1,3 @@
 390             Hello!
 391            -
 392            +How
 393             Bye
 394        "}
 395    );
 396
 397    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 398    assert_eq!(
 399        diff.as_str(),
 400        indoc! {"
 401            @@ -1,3 +1,3 @@
 402             Hello!
 403            -How
 404            +How are you?!
 405             Bye
 406        "}
 407    );
 408}
 409
 410#[gpui::test]
 411async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
 412    let (ep_store, _requests) = init_test_with_fake_client(cx);
 413    let fs = FakeFs::new(cx.executor());
 414
 415    // Create a file with 30 lines to test line-based coalescing
 416    let content = (1..=30)
 417        .map(|i| format!("Line {}\n", i))
 418        .collect::<String>();
 419    fs.insert_tree(
 420        "/root",
 421        json!({
 422            "foo.md": content
 423        }),
 424    )
 425    .await;
 426    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 427
 428    let buffer = project
 429        .update(cx, |project, cx| {
 430            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 431            project.open_buffer(path, cx)
 432        })
 433        .await
 434        .unwrap();
 435
 436    ep_store.update(cx, |ep_store, cx| {
 437        ep_store.register_buffer(&buffer, &project, cx);
 438    });
 439
 440    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
 441    buffer.update(cx, |buffer, cx| {
 442        let start = Point::new(10, 0).to_offset(buffer);
 443        let end = Point::new(13, 0).to_offset(buffer);
 444        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
 445    });
 446
 447    let events = ep_store.update(cx, |ep_store, cx| {
 448        ep_store.edit_history_for_project(&project, cx)
 449    });
 450    assert_eq!(
 451        render_events(&events),
 452        indoc! {"
 453            @@ -8,9 +8,8 @@
 454             Line 8
 455             Line 9
 456             Line 10
 457            -Line 11
 458            -Line 12
 459            -Line 13
 460            +Middle A
 461            +Middle B
 462             Line 14
 463             Line 15
 464             Line 16
 465        "},
 466        "After first edit"
 467    );
 468
 469    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
 470    // This tests that coalescing considers the START of the existing range
 471    buffer.update(cx, |buffer, cx| {
 472        let offset = Point::new(5, 0).to_offset(buffer);
 473        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
 474    });
 475
 476    let events = ep_store.update(cx, |ep_store, cx| {
 477        ep_store.edit_history_for_project(&project, cx)
 478    });
 479    assert_eq!(
 480        render_events(&events),
 481        indoc! {"
 482            @@ -3,14 +3,14 @@
 483             Line 3
 484             Line 4
 485             Line 5
 486            +Above
 487             Line 6
 488             Line 7
 489             Line 8
 490             Line 9
 491             Line 10
 492            -Line 11
 493            -Line 12
 494            -Line 13
 495            +Middle A
 496            +Middle B
 497             Line 14
 498             Line 15
 499             Line 16
 500        "},
 501        "After inserting above (should coalesce)"
 502    );
 503
 504    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
 505    // This tests that coalescing considers the END of the existing range
 506    buffer.update(cx, |buffer, cx| {
 507        let offset = Point::new(14, 0).to_offset(buffer);
 508        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
 509    });
 510
 511    let events = ep_store.update(cx, |ep_store, cx| {
 512        ep_store.edit_history_for_project(&project, cx)
 513    });
 514    assert_eq!(
 515        render_events(&events),
 516        indoc! {"
 517            @@ -3,15 +3,16 @@
 518             Line 3
 519             Line 4
 520             Line 5
 521            +Above
 522             Line 6
 523             Line 7
 524             Line 8
 525             Line 9
 526             Line 10
 527            -Line 11
 528            -Line 12
 529            -Line 13
 530            +Middle A
 531            +Middle B
 532             Line 14
 533            +Below
 534             Line 15
 535             Line 16
 536             Line 17
 537        "},
 538        "After inserting below (should coalesce)"
 539    );
 540
 541    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
 542    // This should NOT coalesce - creates a new event
 543    buffer.update(cx, |buffer, cx| {
 544        let offset = Point::new(25, 0).to_offset(buffer);
 545        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
 546    });
 547
 548    let events = ep_store.update(cx, |ep_store, cx| {
 549        ep_store.edit_history_for_project(&project, cx)
 550    });
 551    assert_eq!(
 552        render_events(&events),
 553        indoc! {"
 554            @@ -3,15 +3,16 @@
 555             Line 3
 556             Line 4
 557             Line 5
 558            +Above
 559             Line 6
 560             Line 7
 561             Line 8
 562             Line 9
 563             Line 10
 564            -Line 11
 565            -Line 12
 566            -Line 13
 567            +Middle A
 568            +Middle B
 569             Line 14
 570            +Below
 571             Line 15
 572             Line 16
 573             Line 17
 574
 575            ---
 576            @@ -23,6 +23,7 @@
 577             Line 22
 578             Line 23
 579             Line 24
 580            +Far below
 581             Line 25
 582             Line 26
 583             Line 27
 584        "},
 585        "After inserting far below (should NOT coalesce)"
 586    );
 587}
 588
 589fn render_events(events: &[StoredEvent]) -> String {
 590    events
 591        .iter()
 592        .map(|e| {
 593            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 594            diff.as_str()
 595        })
 596        .collect::<Vec<_>>()
 597        .join("\n---\n")
 598}
 599
 600#[gpui::test]
 601async fn test_empty_prediction(cx: &mut TestAppContext) {
 602    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 603    let fs = FakeFs::new(cx.executor());
 604    fs.insert_tree(
 605        "/root",
 606        json!({
 607            "foo.md":  "Hello!\nHow\nBye\n"
 608        }),
 609    )
 610    .await;
 611    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 612
 613    let buffer = project
 614        .update(cx, |project, cx| {
 615            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 616            project.open_buffer(path, cx)
 617        })
 618        .await
 619        .unwrap();
 620    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 621    let position = snapshot.anchor_before(language::Point::new(1, 3));
 622
 623    ep_store.update(cx, |ep_store, cx| {
 624        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 625    });
 626
 627    let (request, respond_tx) = requests.predict.next().await.unwrap();
 628    let response = model_response(&request, "");
 629    let id = response.request_id.clone();
 630    respond_tx.send(response).unwrap();
 631
 632    cx.run_until_parked();
 633
 634    ep_store.update(cx, |ep_store, cx| {
 635        assert!(
 636            ep_store
 637                .prediction_at(&buffer, None, &project, cx)
 638                .is_none()
 639        );
 640    });
 641
 642    // prediction is reported as rejected
 643    let (reject_request, _) = requests.reject.next().await.unwrap();
 644
 645    assert_eq!(
 646        &reject_request.rejections,
 647        &[EditPredictionRejection {
 648            request_id: id,
 649            reason: EditPredictionRejectReason::Empty,
 650            was_shown: false
 651        }]
 652    );
 653}
 654
 655#[gpui::test]
 656async fn test_interpolated_empty(cx: &mut TestAppContext) {
 657    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 658    let fs = FakeFs::new(cx.executor());
 659    fs.insert_tree(
 660        "/root",
 661        json!({
 662            "foo.md":  "Hello!\nHow\nBye\n"
 663        }),
 664    )
 665    .await;
 666    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 667
 668    let buffer = project
 669        .update(cx, |project, cx| {
 670            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 671            project.open_buffer(path, cx)
 672        })
 673        .await
 674        .unwrap();
 675    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 676    let position = snapshot.anchor_before(language::Point::new(1, 3));
 677
 678    ep_store.update(cx, |ep_store, cx| {
 679        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 680    });
 681
 682    let (request, respond_tx) = requests.predict.next().await.unwrap();
 683
 684    buffer.update(cx, |buffer, cx| {
 685        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 686    });
 687
 688    let response = model_response(&request, SIMPLE_DIFF);
 689    let id = response.request_id.clone();
 690    respond_tx.send(response).unwrap();
 691
 692    cx.run_until_parked();
 693
 694    ep_store.update(cx, |ep_store, cx| {
 695        assert!(
 696            ep_store
 697                .prediction_at(&buffer, None, &project, cx)
 698                .is_none()
 699        );
 700    });
 701
 702    // prediction is reported as rejected
 703    let (reject_request, _) = requests.reject.next().await.unwrap();
 704
 705    assert_eq!(
 706        &reject_request.rejections,
 707        &[EditPredictionRejection {
 708            request_id: id,
 709            reason: EditPredictionRejectReason::InterpolatedEmpty,
 710            was_shown: false
 711        }]
 712    );
 713}
 714
 715const SIMPLE_DIFF: &str = indoc! { r"
 716    --- a/root/foo.md
 717    +++ b/root/foo.md
 718    @@ ... @@
 719     Hello!
 720    -How
 721    +How are you?
 722     Bye
 723"};
 724
 725#[gpui::test]
 726async fn test_replace_current(cx: &mut TestAppContext) {
 727    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 728    let fs = FakeFs::new(cx.executor());
 729    fs.insert_tree(
 730        "/root",
 731        json!({
 732            "foo.md":  "Hello!\nHow\nBye\n"
 733        }),
 734    )
 735    .await;
 736    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 737
 738    let buffer = project
 739        .update(cx, |project, cx| {
 740            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 741            project.open_buffer(path, cx)
 742        })
 743        .await
 744        .unwrap();
 745    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 746    let position = snapshot.anchor_before(language::Point::new(1, 3));
 747
 748    ep_store.update(cx, |ep_store, cx| {
 749        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 750    });
 751
 752    let (request, respond_tx) = requests.predict.next().await.unwrap();
 753    let first_response = model_response(&request, SIMPLE_DIFF);
 754    let first_id = first_response.request_id.clone();
 755    respond_tx.send(first_response).unwrap();
 756
 757    cx.run_until_parked();
 758
 759    ep_store.update(cx, |ep_store, cx| {
 760        assert_eq!(
 761            ep_store
 762                .prediction_at(&buffer, None, &project, cx)
 763                .unwrap()
 764                .id
 765                .0,
 766            first_id
 767        );
 768    });
 769
 770    // a second request is triggered
 771    ep_store.update(cx, |ep_store, cx| {
 772        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 773    });
 774
 775    let (request, respond_tx) = requests.predict.next().await.unwrap();
 776    let second_response = model_response(&request, SIMPLE_DIFF);
 777    let second_id = second_response.request_id.clone();
 778    respond_tx.send(second_response).unwrap();
 779
 780    cx.run_until_parked();
 781
 782    ep_store.update(cx, |ep_store, cx| {
 783        // second replaces first
 784        assert_eq!(
 785            ep_store
 786                .prediction_at(&buffer, None, &project, cx)
 787                .unwrap()
 788                .id
 789                .0,
 790            second_id
 791        );
 792    });
 793
 794    // first is reported as replaced
 795    let (reject_request, _) = requests.reject.next().await.unwrap();
 796
 797    assert_eq!(
 798        &reject_request.rejections,
 799        &[EditPredictionRejection {
 800            request_id: first_id,
 801            reason: EditPredictionRejectReason::Replaced,
 802            was_shown: false
 803        }]
 804    );
 805}
 806
 807#[gpui::test]
 808async fn test_current_preferred(cx: &mut TestAppContext) {
 809    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 810    let fs = FakeFs::new(cx.executor());
 811    fs.insert_tree(
 812        "/root",
 813        json!({
 814            "foo.md":  "Hello!\nHow\nBye\n"
 815        }),
 816    )
 817    .await;
 818    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 819
 820    let buffer = project
 821        .update(cx, |project, cx| {
 822            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 823            project.open_buffer(path, cx)
 824        })
 825        .await
 826        .unwrap();
 827    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 828    let position = snapshot.anchor_before(language::Point::new(1, 3));
 829
 830    ep_store.update(cx, |ep_store, cx| {
 831        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 832    });
 833
 834    let (request, respond_tx) = requests.predict.next().await.unwrap();
 835    let first_response = model_response(&request, SIMPLE_DIFF);
 836    let first_id = first_response.request_id.clone();
 837    respond_tx.send(first_response).unwrap();
 838
 839    cx.run_until_parked();
 840
 841    ep_store.update(cx, |ep_store, cx| {
 842        assert_eq!(
 843            ep_store
 844                .prediction_at(&buffer, None, &project, cx)
 845                .unwrap()
 846                .id
 847                .0,
 848            first_id
 849        );
 850    });
 851
 852    // a second request is triggered
 853    ep_store.update(cx, |ep_store, cx| {
 854        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 855    });
 856
 857    let (request, respond_tx) = requests.predict.next().await.unwrap();
 858    // worse than current prediction
 859    let second_response = model_response(
 860        &request,
 861        indoc! { r"
 862            --- a/root/foo.md
 863            +++ b/root/foo.md
 864            @@ ... @@
 865             Hello!
 866            -How
 867            +How are
 868             Bye
 869        "},
 870    );
 871    let second_id = second_response.request_id.clone();
 872    respond_tx.send(second_response).unwrap();
 873
 874    cx.run_until_parked();
 875
 876    ep_store.update(cx, |ep_store, cx| {
 877        // first is preferred over second
 878        assert_eq!(
 879            ep_store
 880                .prediction_at(&buffer, None, &project, cx)
 881                .unwrap()
 882                .id
 883                .0,
 884            first_id
 885        );
 886    });
 887
 888    // second is reported as rejected
 889    let (reject_request, _) = requests.reject.next().await.unwrap();
 890
 891    assert_eq!(
 892        &reject_request.rejections,
 893        &[EditPredictionRejection {
 894            request_id: second_id,
 895            reason: EditPredictionRejectReason::CurrentPreferred,
 896            was_shown: false
 897        }]
 898    );
 899}
 900
 901#[gpui::test]
 902async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 903    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 904    let fs = FakeFs::new(cx.executor());
 905    fs.insert_tree(
 906        "/root",
 907        json!({
 908            "foo.md":  "Hello!\nHow\nBye\n"
 909        }),
 910    )
 911    .await;
 912    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 913
 914    let buffer = project
 915        .update(cx, |project, cx| {
 916            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 917            project.open_buffer(path, cx)
 918        })
 919        .await
 920        .unwrap();
 921    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 922    let position = snapshot.anchor_before(language::Point::new(1, 3));
 923
 924    // start two refresh tasks
 925    ep_store.update(cx, |ep_store, cx| {
 926        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 927    });
 928
 929    let (request1, respond_first) = requests.predict.next().await.unwrap();
 930
 931    ep_store.update(cx, |ep_store, cx| {
 932        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 933    });
 934
 935    let (request, respond_second) = requests.predict.next().await.unwrap();
 936
 937    // wait for throttle
 938    cx.run_until_parked();
 939
 940    // second responds first
 941    let second_response = model_response(&request, SIMPLE_DIFF);
 942    let second_id = second_response.request_id.clone();
 943    respond_second.send(second_response).unwrap();
 944
 945    cx.run_until_parked();
 946
 947    ep_store.update(cx, |ep_store, cx| {
 948        // current prediction is second
 949        assert_eq!(
 950            ep_store
 951                .prediction_at(&buffer, None, &project, cx)
 952                .unwrap()
 953                .id
 954                .0,
 955            second_id
 956        );
 957    });
 958
 959    let first_response = model_response(&request1, SIMPLE_DIFF);
 960    let first_id = first_response.request_id.clone();
 961    respond_first.send(first_response).unwrap();
 962
 963    cx.run_until_parked();
 964
 965    ep_store.update(cx, |ep_store, cx| {
 966        // current prediction is still second, since first was cancelled
 967        assert_eq!(
 968            ep_store
 969                .prediction_at(&buffer, None, &project, cx)
 970                .unwrap()
 971                .id
 972                .0,
 973            second_id
 974        );
 975    });
 976
 977    // first is reported as rejected
 978    let (reject_request, _) = requests.reject.next().await.unwrap();
 979
 980    cx.run_until_parked();
 981
 982    assert_eq!(
 983        &reject_request.rejections,
 984        &[EditPredictionRejection {
 985            request_id: first_id,
 986            reason: EditPredictionRejectReason::Canceled,
 987            was_shown: false
 988        }]
 989    );
 990}
 991
 992#[gpui::test]
 993async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 994    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 995    let fs = FakeFs::new(cx.executor());
 996    fs.insert_tree(
 997        "/root",
 998        json!({
 999            "foo.md":  "Hello!\nHow\nBye\n"
1000        }),
1001    )
1002    .await;
1003    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1004
1005    let buffer = project
1006        .update(cx, |project, cx| {
1007            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1008            project.open_buffer(path, cx)
1009        })
1010        .await
1011        .unwrap();
1012    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1013    let position = snapshot.anchor_before(language::Point::new(1, 3));
1014
1015    // start two refresh tasks
1016    ep_store.update(cx, |ep_store, cx| {
1017        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1018    });
1019
1020    let (request1, respond_first) = requests.predict.next().await.unwrap();
1021
1022    ep_store.update(cx, |ep_store, cx| {
1023        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1024    });
1025
1026    let (request2, respond_second) = requests.predict.next().await.unwrap();
1027
1028    // wait for throttle, so requests are sent
1029    cx.run_until_parked();
1030
1031    ep_store.update(cx, |ep_store, cx| {
1032        // start a third request
1033        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1034
1035        // 2 are pending, so 2nd is cancelled
1036        assert_eq!(
1037            ep_store
1038                .get_or_init_project(&project, cx)
1039                .cancelled_predictions
1040                .iter()
1041                .copied()
1042                .collect::<Vec<_>>(),
1043            [1]
1044        );
1045    });
1046
1047    // wait for throttle
1048    cx.run_until_parked();
1049
1050    let (request3, respond_third) = requests.predict.next().await.unwrap();
1051
1052    let first_response = model_response(&request1, SIMPLE_DIFF);
1053    let first_id = first_response.request_id.clone();
1054    respond_first.send(first_response).unwrap();
1055
1056    cx.run_until_parked();
1057
1058    ep_store.update(cx, |ep_store, cx| {
1059        // current prediction is first
1060        assert_eq!(
1061            ep_store
1062                .prediction_at(&buffer, None, &project, cx)
1063                .unwrap()
1064                .id
1065                .0,
1066            first_id
1067        );
1068    });
1069
1070    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
1071    let cancelled_id = cancelled_response.request_id.clone();
1072    respond_second.send(cancelled_response).unwrap();
1073
1074    cx.run_until_parked();
1075
1076    ep_store.update(cx, |ep_store, cx| {
1077        // current prediction is still first, since second was cancelled
1078        assert_eq!(
1079            ep_store
1080                .prediction_at(&buffer, None, &project, cx)
1081                .unwrap()
1082                .id
1083                .0,
1084            first_id
1085        );
1086    });
1087
1088    let third_response = model_response(&request3, SIMPLE_DIFF);
1089    let third_response_id = third_response.request_id.clone();
1090    respond_third.send(third_response).unwrap();
1091
1092    cx.run_until_parked();
1093
1094    ep_store.update(cx, |ep_store, cx| {
1095        // third completes and replaces first
1096        assert_eq!(
1097            ep_store
1098                .prediction_at(&buffer, None, &project, cx)
1099                .unwrap()
1100                .id
1101                .0,
1102            third_response_id
1103        );
1104    });
1105
1106    // second is reported as rejected
1107    let (reject_request, _) = requests.reject.next().await.unwrap();
1108
1109    cx.run_until_parked();
1110
1111    assert_eq!(
1112        &reject_request.rejections,
1113        &[
1114            EditPredictionRejection {
1115                request_id: cancelled_id,
1116                reason: EditPredictionRejectReason::Canceled,
1117                was_shown: false
1118            },
1119            EditPredictionRejection {
1120                request_id: first_id,
1121                reason: EditPredictionRejectReason::Replaced,
1122                was_shown: false
1123            }
1124        ]
1125    );
1126}
1127
1128#[gpui::test]
1129async fn test_rejections_flushing(cx: &mut TestAppContext) {
1130    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1131
1132    ep_store.update(cx, |ep_store, cx| {
1133        ep_store.reject_prediction(
1134            EditPredictionId("test-1".into()),
1135            EditPredictionRejectReason::Discarded,
1136            false,
1137            edit_prediction_types::EditPredictionDismissReason::Ignored,
1138            cx,
1139        );
1140        ep_store.reject_prediction(
1141            EditPredictionId("test-2".into()),
1142            EditPredictionRejectReason::Canceled,
1143            true,
1144            edit_prediction_types::EditPredictionDismissReason::Ignored,
1145            cx,
1146        );
1147    });
1148
1149    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1150    cx.run_until_parked();
1151
1152    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1153    respond_tx.send(()).unwrap();
1154
1155    // batched
1156    assert_eq!(reject_request.rejections.len(), 2);
1157    assert_eq!(
1158        reject_request.rejections[0],
1159        EditPredictionRejection {
1160            request_id: "test-1".to_string(),
1161            reason: EditPredictionRejectReason::Discarded,
1162            was_shown: false
1163        }
1164    );
1165    assert_eq!(
1166        reject_request.rejections[1],
1167        EditPredictionRejection {
1168            request_id: "test-2".to_string(),
1169            reason: EditPredictionRejectReason::Canceled,
1170            was_shown: true
1171        }
1172    );
1173
1174    // Reaching batch size limit sends without debounce
1175    ep_store.update(cx, |ep_store, cx| {
1176        for i in 0..70 {
1177            ep_store.reject_prediction(
1178                EditPredictionId(format!("batch-{}", i).into()),
1179                EditPredictionRejectReason::Discarded,
1180                false,
1181                edit_prediction_types::EditPredictionDismissReason::Ignored,
1182                cx,
1183            );
1184        }
1185    });
1186
1187    // First MAX/2 items are sent immediately
1188    cx.run_until_parked();
1189    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1190    respond_tx.send(()).unwrap();
1191
1192    assert_eq!(reject_request.rejections.len(), 50);
1193    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1194    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1195
1196    // Remaining items are debounced with the next batch
1197    cx.executor().advance_clock(Duration::from_secs(15));
1198    cx.run_until_parked();
1199
1200    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1201    respond_tx.send(()).unwrap();
1202
1203    assert_eq!(reject_request.rejections.len(), 20);
1204    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1205    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1206
1207    // Request failure
1208    ep_store.update(cx, |ep_store, cx| {
1209        ep_store.reject_prediction(
1210            EditPredictionId("retry-1".into()),
1211            EditPredictionRejectReason::Discarded,
1212            false,
1213            edit_prediction_types::EditPredictionDismissReason::Ignored,
1214            cx,
1215        );
1216    });
1217
1218    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1219    cx.run_until_parked();
1220
1221    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1222    assert_eq!(reject_request.rejections.len(), 1);
1223    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1224    // Simulate failure
1225    drop(_respond_tx);
1226
1227    // Add another rejection
1228    ep_store.update(cx, |ep_store, cx| {
1229        ep_store.reject_prediction(
1230            EditPredictionId("retry-2".into()),
1231            EditPredictionRejectReason::Discarded,
1232            false,
1233            edit_prediction_types::EditPredictionDismissReason::Ignored,
1234            cx,
1235        );
1236    });
1237
1238    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1239    cx.run_until_parked();
1240
1241    // Retry should include both the failed item and the new one
1242    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1243    respond_tx.send(()).unwrap();
1244
1245    assert_eq!(reject_request.rejections.len(), 2);
1246    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1247    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1248}
1249
1250// Skipped until we start including diagnostics in prompt
1251// #[gpui::test]
1252// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1253//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1254//     let fs = FakeFs::new(cx.executor());
1255//     fs.insert_tree(
1256//         "/root",
1257//         json!({
1258//             "foo.md": "Hello!\nBye"
1259//         }),
1260//     )
1261//     .await;
1262//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1263
1264//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1265//     let diagnostic = lsp::Diagnostic {
1266//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1267//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1268//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1269//         ..Default::default()
1270//     };
1271
1272//     project.update(cx, |project, cx| {
1273//         project.lsp_store().update(cx, |lsp_store, cx| {
1274//             // Create some diagnostics
1275//             lsp_store
1276//                 .update_diagnostics(
1277//                     LanguageServerId(0),
1278//                     lsp::PublishDiagnosticsParams {
1279//                         uri: path_to_buffer_uri.clone(),
1280//                         diagnostics: vec![diagnostic],
1281//                         version: None,
1282//                     },
1283//                     None,
1284//                     language::DiagnosticSourceKind::Pushed,
1285//                     &[],
1286//                     cx,
1287//                 )
1288//                 .unwrap();
1289//         });
1290//     });
1291
1292//     let buffer = project
1293//         .update(cx, |project, cx| {
1294//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1295//             project.open_buffer(path, cx)
1296//         })
1297//         .await
1298//         .unwrap();
1299
1300//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1301//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1302
1303//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1304//         ep_store.request_prediction(&project, &buffer, position, cx)
1305//     });
1306
1307//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1308
1309//     assert_eq!(request.diagnostic_groups.len(), 1);
1310//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1311//         .unwrap();
1312//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1313//     assert_eq!(
1314//         value,
1315//         json!({
1316//             "entries": [{
1317//                 "range": {
1318//                     "start": 8,
1319//                     "end": 10
1320//                 },
1321//                 "diagnostic": {
1322//                     "source": null,
1323//                     "code": null,
1324//                     "code_description": null,
1325//                     "severity": 1,
1326//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1327//                     "markdown": null,
1328//                     "group_id": 0,
1329//                     "is_primary": true,
1330//                     "is_disk_based": false,
1331//                     "is_unnecessary": false,
1332//                     "source_kind": "Pushed",
1333//                     "data": null,
1334//                     "underline": true
1335//                 }
1336//             }],
1337//             "primary_ix": 0
1338//         })
1339//     );
1340// }
1341
1342// Generate a model response that would apply the given diff to the active file.
1343fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1344    let excerpt =
1345        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1346    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1347
1348    PredictEditsV3Response {
1349        request_id: Uuid::new_v4().to_string(),
1350        output: new_excerpt,
1351    }
1352}
1353
1354fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1355    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
1356}
1357
/// Receiving ends of the channels fed by the fake HTTP client built in
/// `init_test_with_fake_client`.
///
/// Each item pairs the deserialized request body with a oneshot sender. The test
/// replies through that sender; the fake HTTP request stays pending until it does.
struct RequestChannels {
    /// Requests sent to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Requests sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1365
/// Creates the global `EditPredictionStore` on top of a fake HTTP client. Returns
/// it together with the channels through which the test observes (and answers)
/// outgoing requests.
///
/// The fake client handles these paths:
/// - `/client/llm_tokens` answers immediately with `{"token": "test"}`.
/// - `/predict_edits/v3` and `/predict_edits/reject` deserialize the request body,
///   forward it with a oneshot responder on `predict` / `reject`, and wait for the
///   test's reply. If the responder is dropped without replying, `res_rx.await?`
///   makes the fake request return an error.
/// - Any other path panics.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        // Senders stay with the fake HTTP client; receivers are handed to the test.
        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                // Each request gets its own sender clones to move into the future.
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            // Read errors are ignored here; a short body would then
                            // fail to deserialize and panic on the `unwrap` below.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Hand the request to the test and wait for its reply.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            // Same forwarding scheme as `/predict_edits/v3`.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // Fake cloud credentials for the test client.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1434
/// Text of a 0BSD license. The data-collection tests use it as a worktree's
/// `LICENSE` file; those worktrees are the ones expected to allow collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1436
/// Checks that `EditPrediction::interpolate` rebases a prediction's edits onto
/// later buffer snapshots. Edits shift with the user's changes. Typing a prefix
/// of a predicted insertion consumes that prefix. An edit that diverges from
/// the prediction makes interpolation return `None`.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    // The prediction replaces "rem" (2..5) with "REM" and deletes "um" (9..11)
    // in "Lorem ipsum dolor".
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // `inputs` is filled with empty placeholder values.
    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: the edits come back exactly as predicted.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting "rem" turns the replacement into a pure insertion at 2, and
        // the deletion of "um" shifts left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undoing restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Replacing "rem" with "R" (a prefix of "REM") leaves only "EM" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E" leaves only "M" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing "M" completes "REM"; only the deletion of "um" remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the typed "M" brings its insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Deleting "um" by hand satisfies the predicted deletion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // Deleting " i" (4..6) diverges from the prediction, so it no longer applies.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1550
1551#[gpui::test]
1552async fn test_clean_up_diff(cx: &mut TestAppContext) {
1553    init_test(cx);
1554
1555    assert_eq!(
1556        apply_edit_prediction(
1557            indoc! {"
1558                    fn main() {
1559                        let word_1 = \"lorem\";
1560                        let range = word.len()..word.len();
1561                    }
1562                "},
1563            indoc! {"
1564                    <|editable_region_start|>
1565                    fn main() {
1566                        let word_1 = \"lorem\";
1567                        let range = word_1.len()..word_1.len();
1568                    }
1569
1570                    <|editable_region_end|>
1571                "},
1572            cx,
1573        )
1574        .await,
1575        indoc! {"
1576                fn main() {
1577                    let word_1 = \"lorem\";
1578                    let range = word_1.len()..word_1.len();
1579                }
1580            "},
1581    );
1582
1583    assert_eq!(
1584        apply_edit_prediction(
1585            indoc! {"
1586                    fn main() {
1587                        let story = \"the quick\"
1588                    }
1589                "},
1590            indoc! {"
1591                    <|editable_region_start|>
1592                    fn main() {
1593                        let story = \"the quick brown fox jumps over the lazy dog\";
1594                    }
1595
1596                    <|editable_region_end|>
1597                "},
1598            cx,
1599        )
1600        .await,
1601        indoc! {"
1602                fn main() {
1603                    let story = \"the quick brown fox jumps over the lazy dog\";
1604                }
1605            "},
1606    );
1607}
1608
1609#[gpui::test]
1610async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1611    init_test(cx);
1612
1613    let buffer_content = "lorem\n";
1614    let completion_response = indoc! {"
1615            ```animals.js
1616            <|start_of_file|>
1617            <|editable_region_start|>
1618            lorem
1619            ipsum
1620            <|editable_region_end|>
1621            ```"};
1622
1623    assert_eq!(
1624        apply_edit_prediction(buffer_content, completion_response, cx).await,
1625        "lorem\nipsum"
1626    );
1627}
1628
1629#[gpui::test]
1630async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1631    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1632    // When the buffer ends without a trailing newline, but the model returns output
1633    // with a trailing newline, zeta2 should normalize both sides before diffing
1634    // so no spurious newline is inserted.
1635    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1636    let fs = FakeFs::new(cx.executor());
1637
1638    // Single line buffer with no trailing newline
1639    fs.insert_tree(
1640        "/root",
1641        json!({
1642            "foo.txt": "hello"
1643        }),
1644    )
1645    .await;
1646    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1647
1648    let buffer = project
1649        .update(cx, |project, cx| {
1650            let path = project
1651                .find_project_path(path!("root/foo.txt"), cx)
1652                .unwrap();
1653            project.open_buffer(path, cx)
1654        })
1655        .await
1656        .unwrap();
1657
1658    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1659    let position = snapshot.anchor_before(language::Point::new(0, 5));
1660
1661    ep_store.update(cx, |ep_store, cx| {
1662        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1663    });
1664
1665    let (_request, respond_tx) = requests.predict.next().await.unwrap();
1666
1667    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1668    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1669    let response = PredictEditsV3Response {
1670        request_id: Uuid::new_v4().to_string(),
1671        output: "hello world\n".to_string(),
1672    };
1673    respond_tx.send(response).unwrap();
1674
1675    cx.run_until_parked();
1676
1677    // The prediction should insert " world" without adding a newline
1678    ep_store.update(cx, |ep_store, cx| {
1679        let prediction = ep_store
1680            .prediction_at(&buffer, None, &project, cx)
1681            .expect("should have prediction");
1682        let edits: Vec<_> = prediction
1683            .edits
1684            .iter()
1685            .map(|(range, text)| {
1686                let snapshot = buffer.read(cx).snapshot();
1687                (range.to_offset(&snapshot), text.clone())
1688            })
1689            .collect();
1690        assert_eq!(edits, vec![(5..5, " world".into())]);
1691    });
1692}
1693
1694#[gpui::test]
1695async fn test_can_collect_data(cx: &mut TestAppContext) {
1696    init_test(cx);
1697
1698    let fs = project::FakeFs::new(cx.executor());
1699    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1700        .await;
1701
1702    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1703    let buffer = project
1704        .update(cx, |project, cx| {
1705            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1706        })
1707        .await
1708        .unwrap();
1709
1710    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1711    ep_store.update(cx, |ep_store, _cx| {
1712        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1713    });
1714
1715    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1716    assert_eq!(
1717        captured_request.lock().clone().unwrap().can_collect_data,
1718        true
1719    );
1720
1721    ep_store.update(cx, |ep_store, _cx| {
1722        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1723    });
1724
1725    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1726    assert_eq!(
1727        captured_request.lock().clone().unwrap().can_collect_data,
1728        false
1729    );
1730}
1731
1732#[gpui::test]
1733async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1734    init_test(cx);
1735
1736    let fs = project::FakeFs::new(cx.executor());
1737    let project = Project::test(fs.clone(), [], cx).await;
1738
1739    let buffer = cx.new(|_cx| {
1740        Buffer::remote(
1741            language::BufferId::new(1).unwrap(),
1742            ReplicaId::new(1),
1743            language::Capability::ReadWrite,
1744            "fn main() {\n    println!(\"Hello\");\n}",
1745        )
1746    });
1747
1748    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1749    ep_store.update(cx, |ep_store, _cx| {
1750        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1751    });
1752
1753    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1754    assert_eq!(
1755        captured_request.lock().clone().unwrap().can_collect_data,
1756        false
1757    );
1758}
1759
1760#[gpui::test]
1761async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1762    init_test(cx);
1763
1764    let fs = project::FakeFs::new(cx.executor());
1765    fs.insert_tree(
1766        path!("/project"),
1767        json!({
1768            "LICENSE": BSD_0_TXT,
1769            ".env": "SECRET_KEY=secret"
1770        }),
1771    )
1772    .await;
1773
1774    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1775    let buffer = project
1776        .update(cx, |project, cx| {
1777            project.open_local_buffer("/project/.env", cx)
1778        })
1779        .await
1780        .unwrap();
1781
1782    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1783    ep_store.update(cx, |ep_store, _cx| {
1784        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1785    });
1786
1787    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1788    assert_eq!(
1789        captured_request.lock().clone().unwrap().can_collect_data,
1790        false
1791    );
1792}
1793
1794#[gpui::test]
1795async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1796    init_test(cx);
1797
1798    let fs = project::FakeFs::new(cx.executor());
1799    let project = Project::test(fs.clone(), [], cx).await;
1800    let buffer = cx.new(|cx| Buffer::local("", cx));
1801
1802    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1803    ep_store.update(cx, |ep_store, _cx| {
1804        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1805    });
1806
1807    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1808    assert_eq!(
1809        captured_request.lock().clone().unwrap().can_collect_data,
1810        false
1811    );
1812}
1813
1814#[gpui::test]
1815async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1816    init_test(cx);
1817
1818    let fs = project::FakeFs::new(cx.executor());
1819    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1820        .await;
1821
1822    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1823    let buffer = project
1824        .update(cx, |project, cx| {
1825            project.open_local_buffer("/project/main.rs", cx)
1826        })
1827        .await
1828        .unwrap();
1829
1830    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1831    ep_store.update(cx, |ep_store, _cx| {
1832        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1833    });
1834
1835    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1836    assert_eq!(
1837        captured_request.lock().clone().unwrap().can_collect_data,
1838        false
1839    );
1840}
1841
1842#[gpui::test]
1843async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1844    init_test(cx);
1845
1846    let fs = project::FakeFs::new(cx.executor());
1847    fs.insert_tree(
1848        path!("/open_source_worktree"),
1849        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1850    )
1851    .await;
1852    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1853        .await;
1854
1855    let project = Project::test(
1856        fs.clone(),
1857        [
1858            path!("/open_source_worktree").as_ref(),
1859            path!("/closed_source_worktree").as_ref(),
1860        ],
1861        cx,
1862    )
1863    .await;
1864    let buffer = project
1865        .update(cx, |project, cx| {
1866            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1867        })
1868        .await
1869        .unwrap();
1870
1871    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1872    ep_store.update(cx, |ep_store, _cx| {
1873        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1874    });
1875
1876    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1877    assert_eq!(
1878        captured_request.lock().clone().unwrap().can_collect_data,
1879        true
1880    );
1881
1882    let closed_source_file = project
1883        .update(cx, |project, cx| {
1884            let worktree2 = project
1885                .worktree_for_root_name("closed_source_worktree", cx)
1886                .unwrap();
1887            worktree2.update(cx, |worktree2, cx| {
1888                worktree2.load_file(rel_path("main.rs"), cx)
1889            })
1890        })
1891        .await
1892        .unwrap()
1893        .file;
1894
1895    buffer.update(cx, |buffer, cx| {
1896        buffer.file_updated(closed_source_file, cx);
1897    });
1898
1899    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1900    assert_eq!(
1901        captured_request.lock().clone().unwrap().can_collect_data,
1902        false
1903    );
1904}
1905
1906#[gpui::test]
1907async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1908    init_test(cx);
1909
1910    let fs = project::FakeFs::new(cx.executor());
1911    fs.insert_tree(
1912        path!("/worktree1"),
1913        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1914    )
1915    .await;
1916    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1917        .await;
1918
1919    let project = Project::test(
1920        fs.clone(),
1921        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1922        cx,
1923    )
1924    .await;
1925    let buffer = project
1926        .update(cx, |project, cx| {
1927            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1928        })
1929        .await
1930        .unwrap();
1931    let private_buffer = project
1932        .update(cx, |project, cx| {
1933            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1934        })
1935        .await
1936        .unwrap();
1937
1938    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1939    ep_store.update(cx, |ep_store, _cx| {
1940        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1941    });
1942
1943    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1944    assert_eq!(
1945        captured_request.lock().clone().unwrap().can_collect_data,
1946        true
1947    );
1948
1949    // this has a side effect of registering the buffer to watch for edits
1950    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1951    assert_eq!(
1952        captured_request.lock().clone().unwrap().can_collect_data,
1953        false
1954    );
1955
1956    private_buffer.update(cx, |private_buffer, cx| {
1957        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1958    });
1959
1960    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1961    assert_eq!(
1962        captured_request.lock().clone().unwrap().can_collect_data,
1963        false
1964    );
1965
1966    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1967    // included
1968    buffer.update(cx, |buffer, cx| {
1969        buffer.edit(
1970            [(
1971                0..0,
1972                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1973            )],
1974            None,
1975            cx,
1976        );
1977    });
1978
1979    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1980    assert_eq!(
1981        captured_request.lock().clone().unwrap().can_collect_data,
1982        true
1983    );
1984}
1985
1986fn init_test(cx: &mut TestAppContext) {
1987    cx.update(|cx| {
1988        let settings_store = SettingsStore::test(cx);
1989        cx.set_global(settings_store);
1990    });
1991}
1992
1993async fn apply_edit_prediction(
1994    buffer_content: &str,
1995    completion_response: &str,
1996    cx: &mut TestAppContext,
1997) -> String {
1998    let fs = project::FakeFs::new(cx.executor());
1999    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2000    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2001    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2002    *response.lock() = completion_response.to_string();
2003    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2004    buffer.update(cx, |buffer, cx| {
2005        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2006    });
2007    buffer.read_with(cx, |buffer, _| buffer.text())
2008}
2009
2010async fn run_edit_prediction(
2011    buffer: &Entity<Buffer>,
2012    project: &Entity<Project>,
2013    ep_store: &Entity<EditPredictionStore>,
2014    cx: &mut TestAppContext,
2015) -> EditPrediction {
2016    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2017    ep_store.update(cx, |ep_store, cx| {
2018        ep_store.register_buffer(buffer, &project, cx)
2019    });
2020    cx.background_executor.run_until_parked();
2021    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2022        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2023    });
2024    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2025}
2026
2027async fn make_test_ep_store(
2028    project: &Entity<Project>,
2029    cx: &mut TestAppContext,
2030) -> (
2031    Entity<EditPredictionStore>,
2032    Arc<Mutex<Option<PredictEditsBody>>>,
2033    Arc<Mutex<String>>,
2034) {
2035    let default_response = indoc! {"
2036            ```main.rs
2037            <|start_of_file|>
2038            <|editable_region_start|>
2039            hello world
2040            <|editable_region_end|>
2041            ```"
2042    };
2043    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2044    let completion_response: Arc<Mutex<String>> =
2045        Arc::new(Mutex::new(default_response.to_string()));
2046    let http_client = FakeHttpClient::create({
2047        let captured_request = captured_request.clone();
2048        let completion_response = completion_response.clone();
2049        let mut next_request_id = 0;
2050        move |req| {
2051            let captured_request = captured_request.clone();
2052            let completion_response = completion_response.clone();
2053            async move {
2054                match (req.method(), req.uri().path()) {
2055                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2056                        .status(200)
2057                        .body(
2058                            serde_json::to_string(&CreateLlmTokenResponse {
2059                                token: LlmToken("the-llm-token".to_string()),
2060                            })
2061                            .unwrap()
2062                            .into(),
2063                        )
2064                        .unwrap()),
2065                    (&Method::POST, "/predict_edits/v2") => {
2066                        let mut request_body = String::new();
2067                        req.into_body().read_to_string(&mut request_body).await?;
2068                        *captured_request.lock() =
2069                            Some(serde_json::from_str(&request_body).unwrap());
2070                        next_request_id += 1;
2071                        Ok(http_client::Response::builder()
2072                            .status(200)
2073                            .body(
2074                                serde_json::to_string(&PredictEditsResponse {
2075                                    request_id: format!("request-{next_request_id}"),
2076                                    output_excerpt: completion_response.lock().clone(),
2077                                })
2078                                .unwrap()
2079                                .into(),
2080                            )
2081                            .unwrap())
2082                    }
2083                    _ => Ok(http_client::Response::builder()
2084                        .status(404)
2085                        .body("Not Found".into())
2086                        .unwrap()),
2087                }
2088            }
2089        }
2090    });
2091
2092    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2093    cx.update(|cx| {
2094        RefreshLlmTokenListener::register(client.clone(), cx);
2095    });
2096    let _server = FakeServer::for_client(42, &client, cx).await;
2097
2098    let ep_store = cx.new(|cx| {
2099        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2100        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2101
2102        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2103        for worktree in worktrees {
2104            let worktree_id = worktree.read(cx).id();
2105            ep_store
2106                .get_or_init_project(project, cx)
2107                .license_detection_watchers
2108                .entry(worktree_id)
2109                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2110        }
2111
2112        ep_store
2113    });
2114
2115    (ep_store, captured_request, completion_response)
2116}
2117
2118fn to_completion_edits(
2119    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2120    buffer: &Entity<Buffer>,
2121    cx: &App,
2122) -> Vec<(Range<Anchor>, Arc<str>)> {
2123    let buffer = buffer.read(cx);
2124    iterator
2125        .into_iter()
2126        .map(|(range, text)| {
2127            (
2128                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2129                text,
2130            )
2131        })
2132        .collect()
2133}
2134
2135fn from_completion_edits(
2136    editor_edits: &[(Range<Anchor>, Arc<str>)],
2137    buffer: &Entity<Buffer>,
2138    cx: &App,
2139) -> Vec<(Range<usize>, Arc<str>)> {
2140    let buffer = buffer.read(cx);
2141    editor_edits
2142        .iter()
2143        .map(|(range, text)| {
2144            (
2145                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2146                text.clone(),
2147            )
2148        })
2149        .collect()
2150}
2151
2152#[gpui::test]
2153async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2154    init_test(cx);
2155
2156    let fs = FakeFs::new(cx.executor());
2157    fs.insert_tree(
2158        "/project",
2159        serde_json::json!({
2160            "main.rs": "fn main() {\n    \n}\n"
2161        }),
2162    )
2163    .await;
2164
2165    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2166
2167    let http_client = FakeHttpClient::create(|_req| async move {
2168        Ok(gpui::http_client::Response::builder()
2169            .status(401)
2170            .body("Unauthorized".into())
2171            .unwrap())
2172    });
2173
2174    let client =
2175        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2176    cx.update(|cx| {
2177        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2178    });
2179
2180    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2181
2182    let buffer = project
2183        .update(cx, |project, cx| {
2184            let path = project
2185                .find_project_path(path!("/project/main.rs"), cx)
2186                .unwrap();
2187            project.open_buffer(path, cx)
2188        })
2189        .await
2190        .unwrap();
2191
2192    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2193    ep_store.update(cx, |ep_store, cx| {
2194        ep_store.register_buffer(&buffer, &project, cx)
2195    });
2196    cx.background_executor.run_until_parked();
2197
2198    let completion_task = ep_store.update(cx, |ep_store, cx| {
2199        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2200        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2201    });
2202
2203    let result = completion_task.await;
2204    assert!(
2205        result.is_err(),
2206        "Without authentication and without custom URL, prediction should fail"
2207    );
2208}
2209
2210#[gpui::test]
2211async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2212    init_test(cx);
2213
2214    let fs = FakeFs::new(cx.executor());
2215    fs.insert_tree(
2216        "/project",
2217        serde_json::json!({
2218            "main.rs": "fn main() {\n    \n}\n"
2219        }),
2220    )
2221    .await;
2222
2223    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2224
2225    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2226    let predict_called_clone = predict_called.clone();
2227
2228    let http_client = FakeHttpClient::create({
2229        move |req| {
2230            let uri = req.uri().path().to_string();
2231            let predict_called = predict_called_clone.clone();
2232            async move {
2233                if uri.contains("predict") {
2234                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2235                    Ok(gpui::http_client::Response::builder()
2236                        .body(
2237                            serde_json::to_string(&open_ai::Response {
2238                                id: "test-123".to_string(),
2239                                object: "chat.completion".to_string(),
2240                                created: 0,
2241                                model: "test".to_string(),
2242                                usage: open_ai::Usage {
2243                                    prompt_tokens: 0,
2244                                    completion_tokens: 0,
2245                                    total_tokens: 0,
2246                                },
2247                                choices: vec![open_ai::Choice {
2248                                    index: 0,
2249                                    message: open_ai::RequestMessage::Assistant {
2250                                        content: Some(open_ai::MessageContent::Plain(
2251                                            indoc! {"
2252                                                ```main.rs
2253                                                <|start_of_file|>
2254                                                <|editable_region_start|>
2255                                                fn main() {
2256                                                    println!(\"Hello, world!\");
2257                                                }
2258                                                <|editable_region_end|>
2259                                                ```
2260                                            "}
2261                                            .to_string(),
2262                                        )),
2263                                        tool_calls: vec![],
2264                                    },
2265                                    finish_reason: Some("stop".to_string()),
2266                                }],
2267                            })
2268                            .unwrap()
2269                            .into(),
2270                        )
2271                        .unwrap())
2272                } else {
2273                    Ok(gpui::http_client::Response::builder()
2274                        .status(401)
2275                        .body("Unauthorized".into())
2276                        .unwrap())
2277                }
2278            }
2279        }
2280    });
2281
2282    let client =
2283        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2284    cx.update(|cx| {
2285        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2286    });
2287
2288    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2289
2290    let buffer = project
2291        .update(cx, |project, cx| {
2292            let path = project
2293                .find_project_path(path!("/project/main.rs"), cx)
2294                .unwrap();
2295            project.open_buffer(path, cx)
2296        })
2297        .await
2298        .unwrap();
2299
2300    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2301    ep_store.update(cx, |ep_store, cx| {
2302        ep_store.register_buffer(&buffer, &project, cx)
2303    });
2304    cx.background_executor.run_until_parked();
2305
2306    let completion_task = ep_store.update(cx, |ep_store, cx| {
2307        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2308        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2309        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2310    });
2311
2312    let _ = completion_task.await;
2313
2314    assert!(
2315        predict_called.load(std::sync::atomic::Ordering::SeqCst),
2316        "With custom URL, predict endpoint should be called even without authentication"
2317    );
2318}
2319
2320#[gpui::test]
2321fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2322    let buffer = cx.new(|cx| {
2323        Buffer::local(
2324            indoc! {"
2325                zero
2326                one
2327                two
2328                three
2329                four
2330                five
2331                six
2332                seven
2333                eight
2334                nine
2335                ten
2336                eleven
2337                twelve
2338                thirteen
2339                fourteen
2340                fifteen
2341                sixteen
2342                seventeen
2343                eighteen
2344                nineteen
2345                twenty
2346                twenty-one
2347                twenty-two
2348                twenty-three
2349                twenty-four
2350            "},
2351            cx,
2352        )
2353    });
2354
2355    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2356
2357    buffer.update(cx, |buffer, cx| {
2358        let point = Point::new(12, 0);
2359        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2360        let point = Point::new(8, 0);
2361        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2362    });
2363
2364    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2365
2366    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2367
2368    assert_eq!(
2369        diff,
2370        indoc! {"
2371            @@ -6,10 +6,12 @@
2372             five
2373             six
2374             seven
2375            +FIRST INSERTION
2376             eight
2377             nine
2378             ten
2379             eleven
2380            +SECOND INSERTION
2381             twelve
2382             thirteen
2383             fourteen
2384            "}
2385    );
2386}
2387
/// Installs the `zlog` test logger. Marked `#[ctor::ctor]`, so it runs once when the test
/// binary starts, before any test body executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}