edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{
  10        RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
  11    },
  12};
  13use futures::{
  14    AsyncReadExt, StreamExt,
  15    channel::{mpsc, oneshot},
  16};
  17use gpui::{
  18    Entity, TestAppContext,
  19    http_client::{FakeHttpClient, Response},
  20};
  21use indoc::indoc;
  22use language::Point;
  23use lsp::LanguageServerId;
  24use parking_lot::Mutex;
  25use pretty_assertions::{assert_eq, assert_matches};
  26use project::{FakeFs, Project};
  27use serde_json::json;
  28use settings::SettingsStore;
  29use std::{path::Path, sync::Arc, time::Duration};
  30use util::{path, rel_path::rel_path};
  31use uuid::Uuid;
  32use zeta_prompt::ZetaPromptInput;
  33
  34use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  35
/// End-to-end check of `prediction_at` routing: a prediction that edits the
/// queried buffer surfaces as `BufferEditPrediction::Local`, while a
/// prediction that edits a different file (triggered here by a diagnostic in
/// `2.txt`) surfaces as `BufferEditPrediction::Jump` — until that file's own
/// buffer is opened, at which point it is `Local` again.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and mark it as the active path so predictions are anchored to it.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The fake model responds with an edit targeting the file the cursor is in.
    respond_tx
        .send(model_response(
            request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // An edit targeting the queried buffer is reported as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction so the next one starts from a clean slate.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt is expected to trigger a new
    // prediction request on its own — note there is no explicit refresh call
    // before the `requests.predict.next()` below.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // From buffer1's perspective the prediction lives in another file, so it
    // is reported as a `Jump` pointing at 2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Once 2.txt itself is open, the same prediction is `Local` for buffer2.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 171
 172#[gpui::test]
 173async fn test_simple_request(cx: &mut TestAppContext) {
 174    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 175    let fs = FakeFs::new(cx.executor());
 176    fs.insert_tree(
 177        "/root",
 178        json!({
 179            "foo.md":  "Hello!\nHow\nBye\n"
 180        }),
 181    )
 182    .await;
 183    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 184
 185    let buffer = project
 186        .update(cx, |project, cx| {
 187            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 188            project.open_buffer(path, cx)
 189        })
 190        .await
 191        .unwrap();
 192    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 193    let position = snapshot.anchor_before(language::Point::new(1, 3));
 194
 195    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 196        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 197    });
 198
 199    let (request, respond_tx) = requests.predict.next().await.unwrap();
 200
 201    // TODO Put back when we have a structured request again
 202    // assert_eq!(
 203    //     request.excerpt_path.as_ref(),
 204    //     Path::new(path!("root/foo.md"))
 205    // );
 206    // assert_eq!(
 207    //     request.cursor_point,
 208    //     Point {
 209    //         line: Line(1),
 210    //         column: 3
 211    //     }
 212    // );
 213
 214    respond_tx
 215        .send(model_response(
 216            request,
 217            indoc! { r"
 218                --- a/root/foo.md
 219                +++ b/root/foo.md
 220                @@ ... @@
 221                 Hello!
 222                -How
 223                +How are you?
 224                 Bye
 225            "},
 226        ))
 227        .unwrap();
 228
 229    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 230
 231    assert_eq!(prediction.edits.len(), 1);
 232    assert_eq!(
 233        prediction.edits[0].0.to_point(&snapshot).start,
 234        language::Point::new(1, 3)
 235    );
 236    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 237}
 238
/// Verifies that recent buffer edits are forwarded to the model: the diff of
/// an edit made before the request must appear verbatim in the prompt.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Registering the buffer lets the store start tracking its edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Insert "How" at offset 7 (start of the empty second line); this edit
    // should show up as an event diff in the next prediction request.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt must contain the unified diff of the edit made above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // Only the " are you?" suffix remains as the predicted edit.
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 311
/// Checks the pause-splitting edit-history getter: a pause longer than
/// `LAST_CHANGE_GROUPING_TIME` inside a single coalesced event makes
/// `edit_history_for_project_with_pause_split_last_event` return two diffs
/// (pre-pause and post-pause), while the plain `edit_history_for_project`
/// getter still returns the edits as one event.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First diff covers only the pre-pause burst ("How")...
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // ...and the second diff covers everything typed after the pause.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 406
/// Exercises line-span coalescing of edit-history events: edits landing
/// within 8 lines of an existing event's range (whether above its start or
/// below its end) merge into that event, while an edit farther away starts a
/// new event. Expected states are asserted via `render_events` after each edit.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 585
 586fn render_events(events: &[StoredEvent]) -> String {
 587    events
 588        .iter()
 589        .map(|e| {
 590            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 591            diff.as_str()
 592        })
 593        .collect::<Vec<_>>()
 594        .join("\n---\n")
 595}
 596
 597#[gpui::test]
 598async fn test_empty_prediction(cx: &mut TestAppContext) {
 599    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 600    let fs = FakeFs::new(cx.executor());
 601    fs.insert_tree(
 602        "/root",
 603        json!({
 604            "foo.md":  "Hello!\nHow\nBye\n"
 605        }),
 606    )
 607    .await;
 608    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 609
 610    let buffer = project
 611        .update(cx, |project, cx| {
 612            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 613            project.open_buffer(path, cx)
 614        })
 615        .await
 616        .unwrap();
 617    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 618    let position = snapshot.anchor_before(language::Point::new(1, 3));
 619
 620    ep_store.update(cx, |ep_store, cx| {
 621        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 622    });
 623
 624    let (request, respond_tx) = requests.predict.next().await.unwrap();
 625    let response = model_response(request, "");
 626    let id = response.id.clone();
 627    respond_tx.send(response).unwrap();
 628
 629    cx.run_until_parked();
 630
 631    ep_store.update(cx, |ep_store, cx| {
 632        assert!(
 633            ep_store
 634                .prediction_at(&buffer, None, &project, cx)
 635                .is_none()
 636        );
 637    });
 638
 639    // prediction is reported as rejected
 640    let (reject_request, _) = requests.reject.next().await.unwrap();
 641
 642    assert_eq!(
 643        &reject_request.rejections,
 644        &[EditPredictionRejection {
 645            request_id: id,
 646            reason: EditPredictionRejectReason::Empty,
 647            was_shown: false
 648        }]
 649    );
 650}
 651
/// If the buffer is edited to already contain the predicted text before the
/// model's response arrives, the interpolated prediction has nothing left to
/// apply: no prediction is surfaced and the request is reported as rejected
/// with `EditPredictionRejectReason::InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // While the request is in flight, make the buffer already contain the
    // text SIMPLE_DIFF would produce ("How" -> "How are you?").
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(request, SIMPLE_DIFF);
    let id = response.id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The interpolated prediction is empty, so nothing is surfaced.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
 711
/// Canonical model response shared by several tests: a unified diff that
/// rewrites line 2 of `root/foo.md` from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 721
/// When a second refresh completes, its prediction replaces the current one,
/// and the replaced prediction is reported to the server as rejected with
/// `EditPredictionRejectReason::Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction is current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
 803
/// When a newer prediction is judged worse than the current one, the current
/// prediction is kept and the newer one is rejected with
/// `EditPredictionRejectReason::CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(request, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction ("How are you?") is current.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
 897
/// When two requests are in flight and the newer one resolves first, the
/// older request is canceled: its late response must not replace the current
/// prediction, and it is reported with `EditPredictionRejectReason::Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(request, SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now deliver the (stale) first response.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
 988
// Verifies the in-flight prediction limit: when a third refresh starts while
// two requests are still pending, the second (most recent pending) request is
// cancelled. A response arriving for a cancelled request must be ignored, and
// both the cancelled and the replaced prediction must be reported as rejected
// with the appropriate reasons.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Respond to the first (non-cancelled) request; it should become current.
    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled second request; the response must be discarded.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Two rejections in one batch: the cancelled second request, and the first
    // prediction which was replaced by the third.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1124
// Exercises the rejection reporter's batching behavior: debounced flushing,
// immediate sends when the batch-size limit is reached, and re-queuing of
// rejections whose upload request failed.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    // Nothing is sent until the debounce interval elapses.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // The 20 leftover rejections (batch-50..batch-69) arrive in order.
    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    // (dropping the responder makes the fake server's oneshot fail, so the
    // store treats the upload as failed and keeps the rejection queued)
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1236
1237// Skipped until we start including diagnostics in prompt
1238// #[gpui::test]
1239// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1240//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1241//     let fs = FakeFs::new(cx.executor());
1242//     fs.insert_tree(
1243//         "/root",
1244//         json!({
1245//             "foo.md": "Hello!\nBye"
1246//         }),
1247//     )
1248//     .await;
1249//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1250
1251//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1252//     let diagnostic = lsp::Diagnostic {
1253//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1254//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1255//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1256//         ..Default::default()
1257//     };
1258
1259//     project.update(cx, |project, cx| {
1260//         project.lsp_store().update(cx, |lsp_store, cx| {
1261//             // Create some diagnostics
1262//             lsp_store
1263//                 .update_diagnostics(
1264//                     LanguageServerId(0),
1265//                     lsp::PublishDiagnosticsParams {
1266//                         uri: path_to_buffer_uri.clone(),
1267//                         diagnostics: vec![diagnostic],
1268//                         version: None,
1269//                     },
1270//                     None,
1271//                     language::DiagnosticSourceKind::Pushed,
1272//                     &[],
1273//                     cx,
1274//                 )
1275//                 .unwrap();
1276//         });
1277//     });
1278
1279//     let buffer = project
1280//         .update(cx, |project, cx| {
1281//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1282//             project.open_buffer(path, cx)
1283//         })
1284//         .await
1285//         .unwrap();
1286
1287//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1288//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1289
1290//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1291//         ep_store.request_prediction(&project, &buffer, position, cx)
1292//     });
1293
1294//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1295
1296//     assert_eq!(request.diagnostic_groups.len(), 1);
1297//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1298//         .unwrap();
1299//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1300//     assert_eq!(
1301//         value,
1302//         json!({
1303//             "entries": [{
1304//                 "range": {
1305//                     "start": 8,
1306//                     "end": 10
1307//                 },
1308//                 "diagnostic": {
1309//                     "source": null,
1310//                     "code": null,
1311//                     "code_description": null,
1312//                     "severity": 1,
1313//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1314//                     "markdown": null,
1315//                     "group_id": 0,
1316//                     "is_primary": true,
1317//                     "is_disk_based": false,
1318//                     "is_unnecessary": false,
1319//                     "source_kind": "Pushed",
1320//                     "data": null,
1321//                     "underline": true
1322//                 }
1323//             }],
1324//             "primary_ix": 0
1325//         })
1326//     );
1327// }
1328
1329// Generate a model response that would apply the given diff to the active file.
1330fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
1331    let prompt = &request.prompt;
1332
1333    let current_marker = "<|fim_middle|>current\n";
1334    let updated_marker = "<|fim_middle|>updated\n";
1335    let suffix_marker = "<|fim_suffix|>\n";
1336    let cursor = "<|user_cursor|>";
1337
1338    let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
1339    let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
1340    let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1341    // In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content.
1342    // Strip that out to get just the editable region.
1343    let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) {
1344        &excerpt[..suffix_pos]
1345    } else {
1346        &excerpt
1347    };
1348    let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap();
1349
1350    RawCompletionResponse {
1351        id: Uuid::new_v4().to_string(),
1352        object: "text_completion".into(),
1353        created: 0,
1354        model: "model".into(),
1355        choices: vec![RawCompletionChoice {
1356            text: new_excerpt,
1357            finish_reason: None,
1358        }],
1359        usage: RawCompletionUsage {
1360            prompt_tokens: 0,
1361            completion_tokens: 0,
1362            total_tokens: 0,
1363        },
1364    }
1365}
1366
1367fn prompt_from_request(request: &RawCompletionRequest) -> &str {
1368    &request.prompt
1369}
1370
/// Receivers that surface the fake server's incoming requests to the test.
/// Each item pairs the deserialized request body with a oneshot sender the
/// test uses to deliver the HTTP response (dropping the sender simulates a
/// failed request).
struct RequestChannels {
    // Requests hitting /predict_edits/raw.
    predict:
        mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
    // Requests hitting /predict_edits/reject.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1376
/// Builds an `EditPredictionStore` backed by a fake HTTP client.
///
/// The fake client answers `/client/llm_tokens` directly with a static token,
/// and forwards `/predict_edits/raw` and `/predict_edits/reject` request
/// bodies over the returned `RequestChannels`. Each forwarded request blocks
/// its HTTP response until the test replies through the paired oneshot sender;
/// any other path panics.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Hand the request to the test; `?` turns a dropped
                            // sender into a failed HTTP request.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1445
// 0BSD license text; inserted as a LICENSE file to mark a worktree as
// open source in the data-collection tests below.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1447
// Checks that a prediction's edits are re-interpolated against subsequent
// buffer changes: user edits that agree with (or are independent of) the
// prediction shrink/shift the remaining edits, and an edit that conflicts with
// the prediction invalidates it (interpolate returns None).
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Inputs are irrelevant here; only `edits` and the snapshot matter.
    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the first edit's target range turns it into a pure insertion
        // and shifts the second edit left.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted replacement leaves only the
        // untyped remainder to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the final predicted character consumes the first edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting a typed character brings the corresponding insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion by hand consumes the second edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1561
// Applies model responses whose editable regions carry extra artifacts (e.g. a
// trailing blank line before <|editable_region_end|>) and asserts the cleaned
// result — the applied buffer text should not pick up the spurious blank line.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Model renames `word` -> `word_1`; the blank line before the end marker
    // must not leak into the applied text.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
    );

    // Model extends a string literal and adds the missing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
    );
}
1619
1620#[gpui::test]
1621async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1622    init_test(cx);
1623
1624    let buffer_content = "lorem\n";
1625    let completion_response = indoc! {"
1626            ```animals.js
1627            <|start_of_file|>
1628            <|editable_region_start|>
1629            lorem
1630            ipsum
1631            <|editable_region_end|>
1632            ```"};
1633
1634    assert_eq!(
1635        apply_edit_prediction(buffer_content, completion_response, cx).await,
1636        "lorem\nipsum"
1637    );
1638}
1639
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "hello" (row 0, column 5).
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = RawCompletionResponse {
        id: Uuid::new_v4().to_string(),
        object: "text_completion".into(),
        created: 0,
        model: "model".into(),
        choices: vec![RawCompletionChoice {
            text: "hello world\n".to_string(),
            finish_reason: None,
        }],
        usage: RawCompletionUsage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        },
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve anchor ranges to concrete offsets for comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        // Exactly one edit: insert " world" at offset 5, no trailing "\n".
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
1715
1716#[gpui::test]
1717async fn test_can_collect_data(cx: &mut TestAppContext) {
1718    init_test(cx);
1719
1720    let fs = project::FakeFs::new(cx.executor());
1721    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1722        .await;
1723
1724    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1725    let buffer = project
1726        .update(cx, |project, cx| {
1727            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1728        })
1729        .await
1730        .unwrap();
1731
1732    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1733    ep_store.update(cx, |ep_store, _cx| {
1734        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1735    });
1736
1737    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1738    assert_eq!(
1739        captured_request.lock().clone().unwrap().can_collect_data,
1740        true
1741    );
1742
1743    ep_store.update(cx, |ep_store, _cx| {
1744        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1745    });
1746
1747    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1748    assert_eq!(
1749        captured_request.lock().clone().unwrap().can_collect_data,
1750        false
1751    );
1752}
1753
1754#[gpui::test]
1755async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1756    init_test(cx);
1757
1758    let fs = project::FakeFs::new(cx.executor());
1759    let project = Project::test(fs.clone(), [], cx).await;
1760
1761    let buffer = cx.new(|_cx| {
1762        Buffer::remote(
1763            language::BufferId::new(1).unwrap(),
1764            ReplicaId::new(1),
1765            language::Capability::ReadWrite,
1766            "fn main() {\n    println!(\"Hello\");\n}",
1767        )
1768    });
1769
1770    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1771    ep_store.update(cx, |ep_store, _cx| {
1772        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1773    });
1774
1775    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1776    assert_eq!(
1777        captured_request.lock().clone().unwrap().can_collect_data,
1778        false
1779    );
1780}
1781
1782#[gpui::test]
1783async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1784    init_test(cx);
1785
1786    let fs = project::FakeFs::new(cx.executor());
1787    fs.insert_tree(
1788        path!("/project"),
1789        json!({
1790            "LICENSE": BSD_0_TXT,
1791            ".env": "SECRET_KEY=secret"
1792        }),
1793    )
1794    .await;
1795
1796    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1797    let buffer = project
1798        .update(cx, |project, cx| {
1799            project.open_local_buffer("/project/.env", cx)
1800        })
1801        .await
1802        .unwrap();
1803
1804    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1805    ep_store.update(cx, |ep_store, _cx| {
1806        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1807    });
1808
1809    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1810    assert_eq!(
1811        captured_request.lock().clone().unwrap().can_collect_data,
1812        false
1813    );
1814}
1815
1816#[gpui::test]
1817async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1818    init_test(cx);
1819
1820    let fs = project::FakeFs::new(cx.executor());
1821    let project = Project::test(fs.clone(), [], cx).await;
1822    let buffer = cx.new(|cx| Buffer::local("", cx));
1823
1824    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1825    ep_store.update(cx, |ep_store, _cx| {
1826        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1827    });
1828
1829    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1830    assert_eq!(
1831        captured_request.lock().clone().unwrap().can_collect_data,
1832        false
1833    );
1834}
1835
1836#[gpui::test]
1837async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1838    init_test(cx);
1839
1840    let fs = project::FakeFs::new(cx.executor());
1841    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1842        .await;
1843
1844    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1845    let buffer = project
1846        .update(cx, |project, cx| {
1847            project.open_local_buffer("/project/main.rs", cx)
1848        })
1849        .await
1850        .unwrap();
1851
1852    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1853    ep_store.update(cx, |ep_store, _cx| {
1854        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1855    });
1856
1857    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1858    assert_eq!(
1859        captured_request.lock().clone().unwrap().can_collect_data,
1860        false
1861    );
1862}
1863
1864#[gpui::test]
1865async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1866    init_test(cx);
1867
1868    let fs = project::FakeFs::new(cx.executor());
1869    fs.insert_tree(
1870        path!("/open_source_worktree"),
1871        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1872    )
1873    .await;
1874    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1875        .await;
1876
1877    let project = Project::test(
1878        fs.clone(),
1879        [
1880            path!("/open_source_worktree").as_ref(),
1881            path!("/closed_source_worktree").as_ref(),
1882        ],
1883        cx,
1884    )
1885    .await;
1886    let buffer = project
1887        .update(cx, |project, cx| {
1888            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1889        })
1890        .await
1891        .unwrap();
1892
1893    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1894    ep_store.update(cx, |ep_store, _cx| {
1895        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1896    });
1897
1898    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1899    assert_eq!(
1900        captured_request.lock().clone().unwrap().can_collect_data,
1901        true
1902    );
1903
1904    let closed_source_file = project
1905        .update(cx, |project, cx| {
1906            let worktree2 = project
1907                .worktree_for_root_name("closed_source_worktree", cx)
1908                .unwrap();
1909            worktree2.update(cx, |worktree2, cx| {
1910                worktree2.load_file(rel_path("main.rs"), cx)
1911            })
1912        })
1913        .await
1914        .unwrap()
1915        .file;
1916
1917    buffer.update(cx, |buffer, cx| {
1918        buffer.file_updated(closed_source_file, cx);
1919    });
1920
1921    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1922    assert_eq!(
1923        captured_request.lock().clone().unwrap().can_collect_data,
1924        false
1925    );
1926}
1927
1928#[gpui::test]
1929async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1930    init_test(cx);
1931
1932    let fs = project::FakeFs::new(cx.executor());
1933    fs.insert_tree(
1934        path!("/worktree1"),
1935        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1936    )
1937    .await;
1938    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1939        .await;
1940
1941    let project = Project::test(
1942        fs.clone(),
1943        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1944        cx,
1945    )
1946    .await;
1947    let buffer = project
1948        .update(cx, |project, cx| {
1949            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1950        })
1951        .await
1952        .unwrap();
1953    let private_buffer = project
1954        .update(cx, |project, cx| {
1955            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1956        })
1957        .await
1958        .unwrap();
1959
1960    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1961    ep_store.update(cx, |ep_store, _cx| {
1962        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1963    });
1964
1965    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1966    assert_eq!(
1967        captured_request.lock().clone().unwrap().can_collect_data,
1968        true
1969    );
1970
1971    // this has a side effect of registering the buffer to watch for edits
1972    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1973    assert_eq!(
1974        captured_request.lock().clone().unwrap().can_collect_data,
1975        false
1976    );
1977
1978    private_buffer.update(cx, |private_buffer, cx| {
1979        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1980    });
1981
1982    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1983    assert_eq!(
1984        captured_request.lock().clone().unwrap().can_collect_data,
1985        false
1986    );
1987
1988    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1989    // included
1990    buffer.update(cx, |buffer, cx| {
1991        buffer.edit(
1992            [(
1993                0..0,
1994                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1995            )],
1996            None,
1997            cx,
1998        );
1999    });
2000
2001    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2002    assert_eq!(
2003        captured_request.lock().clone().unwrap().can_collect_data,
2004        true
2005    );
2006}
2007
2008fn init_test(cx: &mut TestAppContext) {
2009    cx.update(|cx| {
2010        let settings_store = SettingsStore::test(cx);
2011        cx.set_global(settings_store);
2012    });
2013}
2014
2015async fn apply_edit_prediction(
2016    buffer_content: &str,
2017    completion_response: &str,
2018    cx: &mut TestAppContext,
2019) -> String {
2020    let fs = project::FakeFs::new(cx.executor());
2021    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2022    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2023    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2024    *response.lock() = completion_response.to_string();
2025    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2026    buffer.update(cx, |buffer, cx| {
2027        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2028    });
2029    buffer.read_with(cx, |buffer, _| buffer.text())
2030}
2031
2032async fn run_edit_prediction(
2033    buffer: &Entity<Buffer>,
2034    project: &Entity<Project>,
2035    ep_store: &Entity<EditPredictionStore>,
2036    cx: &mut TestAppContext,
2037) -> EditPrediction {
2038    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2039    ep_store.update(cx, |ep_store, cx| {
2040        ep_store.register_buffer(buffer, &project, cx)
2041    });
2042    cx.background_executor.run_until_parked();
2043    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2044        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2045    });
2046    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2047}
2048
/// Builds an `EditPredictionStore` backed by a fake HTTP client and fake server.
///
/// Returns the store, a slot that captures the most recent `PredictEditsBody`
/// sent to `/predict_edits/v2`, and a slot whose contents are served as the
/// prediction output excerpt (pre-filled with a trivial "hello world" excerpt).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    // Minimal well-formed output excerpt used when a test doesn't override it.
    let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        // Counter captured mutably by the handler; gives each prediction
        // response a distinct request id.
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh endpoint: always hand back a fixed LLM token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction endpoint: record the request body for test
                    // assertions and serve the configurable response excerpt.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Eagerly install a license-detection watcher per worktree so the
        // open-source checks in these tests have data to work with.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2139
2140fn to_completion_edits(
2141    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2142    buffer: &Entity<Buffer>,
2143    cx: &App,
2144) -> Vec<(Range<Anchor>, Arc<str>)> {
2145    let buffer = buffer.read(cx);
2146    iterator
2147        .into_iter()
2148        .map(|(range, text)| {
2149            (
2150                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2151                text,
2152            )
2153        })
2154        .collect()
2155}
2156
2157fn from_completion_edits(
2158    editor_edits: &[(Range<Anchor>, Arc<str>)],
2159    buffer: &Entity<Buffer>,
2160    cx: &App,
2161) -> Vec<(Range<usize>, Arc<str>)> {
2162    let buffer = buffer.read(cx);
2163    editor_edits
2164        .iter()
2165        .map(|(range, text)| {
2166            (
2167                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2168                text.clone(),
2169            )
2170        })
2171        .collect()
2172}
2173
2174#[gpui::test]
2175async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2176    init_test(cx);
2177
2178    let fs = FakeFs::new(cx.executor());
2179    fs.insert_tree(
2180        "/project",
2181        serde_json::json!({
2182            "main.rs": "fn main() {\n    \n}\n"
2183        }),
2184    )
2185    .await;
2186
2187    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2188
2189    let http_client = FakeHttpClient::create(|_req| async move {
2190        Ok(gpui::http_client::Response::builder()
2191            .status(401)
2192            .body("Unauthorized".into())
2193            .unwrap())
2194    });
2195
2196    let client =
2197        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2198    cx.update(|cx| {
2199        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2200    });
2201
2202    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2203
2204    let buffer = project
2205        .update(cx, |project, cx| {
2206            let path = project
2207                .find_project_path(path!("/project/main.rs"), cx)
2208                .unwrap();
2209            project.open_buffer(path, cx)
2210        })
2211        .await
2212        .unwrap();
2213
2214    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2215    ep_store.update(cx, |ep_store, cx| {
2216        ep_store.register_buffer(&buffer, &project, cx)
2217    });
2218    cx.background_executor.run_until_parked();
2219
2220    let completion_task = ep_store.update(cx, |ep_store, cx| {
2221        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2222        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2223    });
2224
2225    let result = completion_task.await;
2226    assert!(
2227        result.is_err(),
2228        "Without authentication and without custom URL, prediction should fail"
2229    );
2230}
2231
2232#[gpui::test]
2233async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2234    init_test(cx);
2235
2236    let fs = FakeFs::new(cx.executor());
2237    fs.insert_tree(
2238        "/project",
2239        serde_json::json!({
2240            "main.rs": "fn main() {\n    \n}\n"
2241        }),
2242    )
2243    .await;
2244
2245    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2246
2247    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2248    let predict_called_clone = predict_called.clone();
2249
2250    let http_client = FakeHttpClient::create({
2251        move |req| {
2252            let uri = req.uri().path().to_string();
2253            let predict_called = predict_called_clone.clone();
2254            async move {
2255                if uri.contains("predict") {
2256                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2257                    Ok(gpui::http_client::Response::builder()
2258                        .body(
2259                            serde_json::to_string(&open_ai::Response {
2260                                id: "test-123".to_string(),
2261                                object: "chat.completion".to_string(),
2262                                created: 0,
2263                                model: "test".to_string(),
2264                                usage: open_ai::Usage {
2265                                    prompt_tokens: 0,
2266                                    completion_tokens: 0,
2267                                    total_tokens: 0,
2268                                },
2269                                choices: vec![open_ai::Choice {
2270                                    index: 0,
2271                                    message: open_ai::RequestMessage::Assistant {
2272                                        content: Some(open_ai::MessageContent::Plain(
2273                                            indoc! {"
2274                                                ```main.rs
2275                                                <|start_of_file|>
2276                                                <|editable_region_start|>
2277                                                fn main() {
2278                                                    println!(\"Hello, world!\");
2279                                                }
2280                                                <|editable_region_end|>
2281                                                ```
2282                                            "}
2283                                            .to_string(),
2284                                        )),
2285                                        tool_calls: vec![],
2286                                    },
2287                                    finish_reason: Some("stop".to_string()),
2288                                }],
2289                            })
2290                            .unwrap()
2291                            .into(),
2292                        )
2293                        .unwrap())
2294                } else {
2295                    Ok(gpui::http_client::Response::builder()
2296                        .status(401)
2297                        .body("Unauthorized".into())
2298                        .unwrap())
2299                }
2300            }
2301        }
2302    });
2303
2304    let client =
2305        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2306    cx.update(|cx| {
2307        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2308    });
2309
2310    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2311
2312    let buffer = project
2313        .update(cx, |project, cx| {
2314            let path = project
2315                .find_project_path(path!("/project/main.rs"), cx)
2316                .unwrap();
2317            project.open_buffer(path, cx)
2318        })
2319        .await
2320        .unwrap();
2321
2322    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2323    ep_store.update(cx, |ep_store, cx| {
2324        ep_store.register_buffer(&buffer, &project, cx)
2325    });
2326    cx.background_executor.run_until_parked();
2327
2328    let completion_task = ep_store.update(cx, |ep_store, cx| {
2329        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2330        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2331        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2332    });
2333
2334    let _ = completion_task.await;
2335
2336    assert!(
2337        predict_called.load(std::sync::atomic::Ordering::SeqCst),
2338        "With custom URL, predict endpoint should be called even without authentication"
2339    );
2340}
2341
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    buffer.update(cx, |buffer, cx| {
        // Insert at the later point first so the earlier point stays valid —
        // the edits are applied in order and each shifts subsequent rows.
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Both insertions are close enough together to land in a single hunk with
    // three lines of context on either side.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
            "}
    );
}
2409
// Runs once at test-binary startup (before any test executes) so log output is
// available for every test in this file.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}