edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  10};
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::App;
  16use gpui::{
  17    Entity, TestAppContext,
  18    http_client::{FakeHttpClient, Response},
  19};
  20use indoc::indoc;
  21use language::{Buffer, Point};
  22use lsp::LanguageServerId;
  23use parking_lot::Mutex;
  24use pretty_assertions::{assert_eq, assert_matches};
  25use project::{FakeFs, Project};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, sync::Arc, time::Duration};
  29use util::{path, rel_path::rel_path};
  30use uuid::Uuid;
  31use zeta_prompt::ZetaPromptInput;
  32
  33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  34
/// Exercises how the store reports the "current" prediction relative to a buffer.
///
/// 1. A prediction requested for the active buffer `1.txt` is reported as
///    `BufferEditPrediction::Local` for `buffer1`.
/// 2. After that prediction is rejected, a diagnostic is published for `2.txt`.
///    The test then waits for a new prediction request, so it expects the store
///    to issue one in response to the diagnostic. The answer edits `2.txt`, so it
///    is reported as `BufferEditPrediction::Jump` from `buffer1`.
/// 3. Once `2.txt` is opened, the same prediction is `Local` for `buffer2`.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open `1.txt` and make it the project's active path.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" (row 1, column 3).
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Clear the current prediction so the next one comes from the diagnostic.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publish the diagnostic for `2.txt` as if it came from language server 0.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    // No explicit refresh call precedes this: the request is expected to be
    // triggered by the diagnostic update above.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Seen from `buffer1`, an edit in `2.txt` is a jump to the other file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Seen from `buffer2` itself, the same prediction is local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 170
 171#[gpui::test]
 172async fn test_simple_request(cx: &mut TestAppContext) {
 173    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 174    let fs = FakeFs::new(cx.executor());
 175    fs.insert_tree(
 176        "/root",
 177        json!({
 178            "foo.md":  "Hello!\nHow\nBye\n"
 179        }),
 180    )
 181    .await;
 182    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 183
 184    let buffer = project
 185        .update(cx, |project, cx| {
 186            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 187            project.open_buffer(path, cx)
 188        })
 189        .await
 190        .unwrap();
 191    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 192    let position = snapshot.anchor_before(language::Point::new(1, 3));
 193
 194    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 195        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 196    });
 197
 198    let (request, respond_tx) = requests.predict.next().await.unwrap();
 199
 200    // TODO Put back when we have a structured request again
 201    // assert_eq!(
 202    //     request.excerpt_path.as_ref(),
 203    //     Path::new(path!("root/foo.md"))
 204    // );
 205    // assert_eq!(
 206    //     request.cursor_point,
 207    //     Point {
 208    //         line: Line(1),
 209    //         column: 3
 210    //     }
 211    // );
 212
 213    respond_tx
 214        .send(model_response(
 215            &request,
 216            indoc! { r"
 217                --- a/root/foo.md
 218                +++ b/root/foo.md
 219                @@ ... @@
 220                 Hello!
 221                -How
 222                +How are you?
 223                 Bye
 224            "},
 225        ))
 226        .unwrap();
 227
 228    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 229
 230    assert_eq!(prediction.edits.len(), 1);
 231    assert_eq!(
 232        prediction.edits[0].0.to_point(&snapshot).start,
 233        language::Point::new(1, 3)
 234    );
 235    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 236}
 237
 238#[gpui::test]
 239async fn test_request_events(cx: &mut TestAppContext) {
 240    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 241    let fs = FakeFs::new(cx.executor());
 242    fs.insert_tree(
 243        "/root",
 244        json!({
 245            "foo.md": "Hello!\n\nBye\n"
 246        }),
 247    )
 248    .await;
 249    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 250
 251    let buffer = project
 252        .update(cx, |project, cx| {
 253            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 254            project.open_buffer(path, cx)
 255        })
 256        .await
 257        .unwrap();
 258
 259    ep_store.update(cx, |ep_store, cx| {
 260        ep_store.register_buffer(&buffer, &project, cx);
 261    });
 262
 263    buffer.update(cx, |buffer, cx| {
 264        buffer.edit(vec![(7..7, "How")], None, cx);
 265    });
 266
 267    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 268    let position = snapshot.anchor_before(language::Point::new(1, 3));
 269
 270    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 271        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 272    });
 273
 274    let (request, respond_tx) = requests.predict.next().await.unwrap();
 275
 276    let prompt = prompt_from_request(&request);
 277    assert!(
 278        prompt.contains(indoc! {"
 279        --- a/root/foo.md
 280        +++ b/root/foo.md
 281        @@ -1,3 +1,3 @@
 282         Hello!
 283        -
 284        +How
 285         Bye
 286    "}),
 287        "{prompt}"
 288    );
 289
 290    respond_tx
 291        .send(model_response(
 292            &request,
 293            indoc! {r#"
 294                --- a/root/foo.md
 295                +++ b/root/foo.md
 296                @@ ... @@
 297                 Hello!
 298                -How
 299                +How are you?
 300                 Bye
 301        "#},
 302        ))
 303        .unwrap();
 304
 305    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 306
 307    assert_eq!(prediction.edits.len(), 1);
 308    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 309}
 310
/// Compares the two edit-history getters on a last event that spans a pause.
///
/// Three edits land on the same line: "How", then (after advancing the clock by
/// twice `LAST_CHANGE_GROUPING_TIME`) " are you?" and "!". The test expects
/// `edit_history_for_project` to return a single coalesced event. It expects
/// `edit_history_for_project_with_pause_split_last_event` to return two events,
/// split at the pause.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How" at offset 7 (start of the empty second line).
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause: advance the fake clock past the grouping threshold
    // (twice `LAST_CHANGE_GROUPING_TIME`).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line
    // (offset 10 = 7 + len("How")).
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    // (offset 19 = 10 + len(" are you?"))
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // Before the pause: empty line -> "How".
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // After the pause: "How" -> "How are you?!".
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 405
/// Verifies that nearby edits coalesce into one history event and distant edits do not.
///
/// Starting from a 30-line file, the test makes three edits. The first is a
/// multi-line replacement. The second is an insertion above it and the third an
/// insertion below it; both are expected to merge into the first event. A fourth
/// insertion much further down is expected to start a new event. After each step
/// the rendered history is compared with an exact diff. The "within 8 lines"
/// distances in the comments describe the test's intent; the threshold itself is
/// defined outside this test.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range.
    // Row 14 of the current buffer is "Line 15": the "Above" insertion shifted rows
    // down by one, and the 3-into-2 replacement shifted later rows up by one.
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    // `render_events` joins separate events with a "---" line.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 584
 585fn render_events(events: &[StoredEvent]) -> String {
 586    events
 587        .iter()
 588        .map(|e| {
 589            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 590            diff.as_str()
 591        })
 592        .collect::<Vec<_>>()
 593        .join("\n---\n")
 594}
 595
 596#[gpui::test]
 597async fn test_empty_prediction(cx: &mut TestAppContext) {
 598    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 599    let fs = FakeFs::new(cx.executor());
 600    fs.insert_tree(
 601        "/root",
 602        json!({
 603            "foo.md":  "Hello!\nHow\nBye\n"
 604        }),
 605    )
 606    .await;
 607    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 608
 609    let buffer = project
 610        .update(cx, |project, cx| {
 611            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 612            project.open_buffer(path, cx)
 613        })
 614        .await
 615        .unwrap();
 616    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 617    let position = snapshot.anchor_before(language::Point::new(1, 3));
 618
 619    ep_store.update(cx, |ep_store, cx| {
 620        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 621    });
 622
 623    let (request, respond_tx) = requests.predict.next().await.unwrap();
 624    let response = model_response(&request, "");
 625    let id = response.request_id.clone();
 626    respond_tx.send(response).unwrap();
 627
 628    cx.run_until_parked();
 629
 630    ep_store.update(cx, |ep_store, cx| {
 631        assert!(
 632            ep_store
 633                .prediction_at(&buffer, None, &project, cx)
 634                .is_none()
 635        );
 636    });
 637
 638    // prediction is reported as rejected
 639    let (reject_request, _) = requests.reject.next().await.unwrap();
 640
 641    assert_eq!(
 642        &reject_request.rejections,
 643        &[EditPredictionRejection {
 644            request_id: id,
 645            reason: EditPredictionRejectReason::Empty,
 646            was_shown: false
 647        }]
 648    );
 649}
 650
 651#[gpui::test]
 652async fn test_interpolated_empty(cx: &mut TestAppContext) {
 653    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 654    let fs = FakeFs::new(cx.executor());
 655    fs.insert_tree(
 656        "/root",
 657        json!({
 658            "foo.md":  "Hello!\nHow\nBye\n"
 659        }),
 660    )
 661    .await;
 662    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 663
 664    let buffer = project
 665        .update(cx, |project, cx| {
 666            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 667            project.open_buffer(path, cx)
 668        })
 669        .await
 670        .unwrap();
 671    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 672    let position = snapshot.anchor_before(language::Point::new(1, 3));
 673
 674    ep_store.update(cx, |ep_store, cx| {
 675        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 676    });
 677
 678    let (request, respond_tx) = requests.predict.next().await.unwrap();
 679
 680    buffer.update(cx, |buffer, cx| {
 681        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 682    });
 683
 684    let response = model_response(&request, SIMPLE_DIFF);
 685    let id = response.request_id.clone();
 686    respond_tx.send(response).unwrap();
 687
 688    cx.run_until_parked();
 689
 690    ep_store.update(cx, |ep_store, cx| {
 691        assert!(
 692            ep_store
 693                .prediction_at(&buffer, None, &project, cx)
 694                .is_none()
 695        );
 696    });
 697
 698    // prediction is reported as rejected
 699    let (reject_request, _) = requests.reject.next().await.unwrap();
 700
 701    assert_eq!(
 702        &reject_request.rejections,
 703        &[EditPredictionRejection {
 704            request_id: id,
 705            reason: EditPredictionRejectReason::InterpolatedEmpty,
 706            was_shown: false
 707        }]
 708    );
 709}
 710
/// Canned model response shared by several tests. It is a unified diff against
/// `root/foo.md`, with an elided hunk header (`@@ ... @@`), that rewrites the
/// line `How` into `How are you?` between `Hello!` and `Bye`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 720
 721#[gpui::test]
 722async fn test_replace_current(cx: &mut TestAppContext) {
 723    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 724    let fs = FakeFs::new(cx.executor());
 725    fs.insert_tree(
 726        "/root",
 727        json!({
 728            "foo.md":  "Hello!\nHow\nBye\n"
 729        }),
 730    )
 731    .await;
 732    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 733
 734    let buffer = project
 735        .update(cx, |project, cx| {
 736            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 737            project.open_buffer(path, cx)
 738        })
 739        .await
 740        .unwrap();
 741    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 742    let position = snapshot.anchor_before(language::Point::new(1, 3));
 743
 744    ep_store.update(cx, |ep_store, cx| {
 745        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 746    });
 747
 748    let (request, respond_tx) = requests.predict.next().await.unwrap();
 749    let first_response = model_response(&request, SIMPLE_DIFF);
 750    let first_id = first_response.request_id.clone();
 751    respond_tx.send(first_response).unwrap();
 752
 753    cx.run_until_parked();
 754
 755    ep_store.update(cx, |ep_store, cx| {
 756        assert_eq!(
 757            ep_store
 758                .prediction_at(&buffer, None, &project, cx)
 759                .unwrap()
 760                .id
 761                .0,
 762            first_id
 763        );
 764    });
 765
 766    // a second request is triggered
 767    ep_store.update(cx, |ep_store, cx| {
 768        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 769    });
 770
 771    let (request, respond_tx) = requests.predict.next().await.unwrap();
 772    let second_response = model_response(&request, SIMPLE_DIFF);
 773    let second_id = second_response.request_id.clone();
 774    respond_tx.send(second_response).unwrap();
 775
 776    cx.run_until_parked();
 777
 778    ep_store.update(cx, |ep_store, cx| {
 779        // second replaces first
 780        assert_eq!(
 781            ep_store
 782                .prediction_at(&buffer, None, &project, cx)
 783                .unwrap()
 784                .id
 785                .0,
 786            second_id
 787        );
 788    });
 789
 790    // first is reported as replaced
 791    let (reject_request, _) = requests.reject.next().await.unwrap();
 792
 793    assert_eq!(
 794        &reject_request.rejections,
 795        &[EditPredictionRejection {
 796            request_id: first_id,
 797            reason: EditPredictionRejectReason::Replaced,
 798            was_shown: false
 799        }]
 800    );
 801}
 802
 803#[gpui::test]
 804async fn test_current_preferred(cx: &mut TestAppContext) {
 805    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 806    let fs = FakeFs::new(cx.executor());
 807    fs.insert_tree(
 808        "/root",
 809        json!({
 810            "foo.md":  "Hello!\nHow\nBye\n"
 811        }),
 812    )
 813    .await;
 814    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 815
 816    let buffer = project
 817        .update(cx, |project, cx| {
 818            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 819            project.open_buffer(path, cx)
 820        })
 821        .await
 822        .unwrap();
 823    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 824    let position = snapshot.anchor_before(language::Point::new(1, 3));
 825
 826    ep_store.update(cx, |ep_store, cx| {
 827        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 828    });
 829
 830    let (request, respond_tx) = requests.predict.next().await.unwrap();
 831    let first_response = model_response(&request, SIMPLE_DIFF);
 832    let first_id = first_response.request_id.clone();
 833    respond_tx.send(first_response).unwrap();
 834
 835    cx.run_until_parked();
 836
 837    ep_store.update(cx, |ep_store, cx| {
 838        assert_eq!(
 839            ep_store
 840                .prediction_at(&buffer, None, &project, cx)
 841                .unwrap()
 842                .id
 843                .0,
 844            first_id
 845        );
 846    });
 847
 848    // a second request is triggered
 849    ep_store.update(cx, |ep_store, cx| {
 850        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 851    });
 852
 853    let (request, respond_tx) = requests.predict.next().await.unwrap();
 854    // worse than current prediction
 855    let second_response = model_response(
 856        &request,
 857        indoc! { r"
 858            --- a/root/foo.md
 859            +++ b/root/foo.md
 860            @@ ... @@
 861             Hello!
 862            -How
 863            +How are
 864             Bye
 865        "},
 866    );
 867    let second_id = second_response.request_id.clone();
 868    respond_tx.send(second_response).unwrap();
 869
 870    cx.run_until_parked();
 871
 872    ep_store.update(cx, |ep_store, cx| {
 873        // first is preferred over second
 874        assert_eq!(
 875            ep_store
 876                .prediction_at(&buffer, None, &project, cx)
 877                .unwrap()
 878                .id
 879                .0,
 880            first_id
 881        );
 882    });
 883
 884    // second is reported as rejected
 885    let (reject_request, _) = requests.reject.next().await.unwrap();
 886
 887    assert_eq!(
 888        &reject_request.rejections,
 889        &[EditPredictionRejection {
 890            request_id: second_id,
 891            reason: EditPredictionRejectReason::CurrentPreferred,
 892            was_shown: false
 893        }]
 894    );
 895}
 896
 897#[gpui::test]
 898async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 899    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 900    let fs = FakeFs::new(cx.executor());
 901    fs.insert_tree(
 902        "/root",
 903        json!({
 904            "foo.md":  "Hello!\nHow\nBye\n"
 905        }),
 906    )
 907    .await;
 908    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 909
 910    let buffer = project
 911        .update(cx, |project, cx| {
 912            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 913            project.open_buffer(path, cx)
 914        })
 915        .await
 916        .unwrap();
 917    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 918    let position = snapshot.anchor_before(language::Point::new(1, 3));
 919
 920    // start two refresh tasks
 921    ep_store.update(cx, |ep_store, cx| {
 922        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 923    });
 924
 925    let (request1, respond_first) = requests.predict.next().await.unwrap();
 926
 927    ep_store.update(cx, |ep_store, cx| {
 928        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 929    });
 930
 931    let (request, respond_second) = requests.predict.next().await.unwrap();
 932
 933    // wait for throttle
 934    cx.run_until_parked();
 935
 936    // second responds first
 937    let second_response = model_response(&request, SIMPLE_DIFF);
 938    let second_id = second_response.request_id.clone();
 939    respond_second.send(second_response).unwrap();
 940
 941    cx.run_until_parked();
 942
 943    ep_store.update(cx, |ep_store, cx| {
 944        // current prediction is second
 945        assert_eq!(
 946            ep_store
 947                .prediction_at(&buffer, None, &project, cx)
 948                .unwrap()
 949                .id
 950                .0,
 951            second_id
 952        );
 953    });
 954
 955    let first_response = model_response(&request1, SIMPLE_DIFF);
 956    let first_id = first_response.request_id.clone();
 957    respond_first.send(first_response).unwrap();
 958
 959    cx.run_until_parked();
 960
 961    ep_store.update(cx, |ep_store, cx| {
 962        // current prediction is still second, since first was cancelled
 963        assert_eq!(
 964            ep_store
 965                .prediction_at(&buffer, None, &project, cx)
 966                .unwrap()
 967                .id
 968                .0,
 969            second_id
 970        );
 971    });
 972
 973    // first is reported as rejected
 974    let (reject_request, _) = requests.reject.next().await.unwrap();
 975
 976    cx.run_until_parked();
 977
 978    assert_eq!(
 979        &reject_request.rejections,
 980        &[EditPredictionRejection {
 981            request_id: first_id,
 982            reason: EditPredictionRejectReason::Canceled,
 983            was_shown: false
 984        }]
 985    );
 986}
 987
 988#[gpui::test]
 989async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 990    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 991    let fs = FakeFs::new(cx.executor());
 992    fs.insert_tree(
 993        "/root",
 994        json!({
 995            "foo.md":  "Hello!\nHow\nBye\n"
 996        }),
 997    )
 998    .await;
 999    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1000
1001    let buffer = project
1002        .update(cx, |project, cx| {
1003            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1004            project.open_buffer(path, cx)
1005        })
1006        .await
1007        .unwrap();
1008    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1009    let position = snapshot.anchor_before(language::Point::new(1, 3));
1010
1011    // start two refresh tasks
1012    ep_store.update(cx, |ep_store, cx| {
1013        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1014    });
1015
1016    let (request1, respond_first) = requests.predict.next().await.unwrap();
1017
1018    ep_store.update(cx, |ep_store, cx| {
1019        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1020    });
1021
1022    let (request2, respond_second) = requests.predict.next().await.unwrap();
1023
1024    // wait for throttle, so requests are sent
1025    cx.run_until_parked();
1026
1027    ep_store.update(cx, |ep_store, cx| {
1028        // start a third request
1029        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1030
1031        // 2 are pending, so 2nd is cancelled
1032        assert_eq!(
1033            ep_store
1034                .get_or_init_project(&project, cx)
1035                .cancelled_predictions
1036                .iter()
1037                .copied()
1038                .collect::<Vec<_>>(),
1039            [1]
1040        );
1041    });
1042
1043    // wait for throttle
1044    cx.run_until_parked();
1045
1046    let (request3, respond_third) = requests.predict.next().await.unwrap();
1047
1048    let first_response = model_response(&request1, SIMPLE_DIFF);
1049    let first_id = first_response.request_id.clone();
1050    respond_first.send(first_response).unwrap();
1051
1052    cx.run_until_parked();
1053
1054    ep_store.update(cx, |ep_store, cx| {
1055        // current prediction is first
1056        assert_eq!(
1057            ep_store
1058                .prediction_at(&buffer, None, &project, cx)
1059                .unwrap()
1060                .id
1061                .0,
1062            first_id
1063        );
1064    });
1065
1066    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
1067    let cancelled_id = cancelled_response.request_id.clone();
1068    respond_second.send(cancelled_response).unwrap();
1069
1070    cx.run_until_parked();
1071
1072    ep_store.update(cx, |ep_store, cx| {
1073        // current prediction is still first, since second was cancelled
1074        assert_eq!(
1075            ep_store
1076                .prediction_at(&buffer, None, &project, cx)
1077                .unwrap()
1078                .id
1079                .0,
1080            first_id
1081        );
1082    });
1083
1084    let third_response = model_response(&request3, SIMPLE_DIFF);
1085    let third_response_id = third_response.request_id.clone();
1086    respond_third.send(third_response).unwrap();
1087
1088    cx.run_until_parked();
1089
1090    ep_store.update(cx, |ep_store, cx| {
1091        // third completes and replaces first
1092        assert_eq!(
1093            ep_store
1094                .prediction_at(&buffer, None, &project, cx)
1095                .unwrap()
1096                .id
1097                .0,
1098            third_response_id
1099        );
1100    });
1101
1102    // second is reported as rejected
1103    let (reject_request, _) = requests.reject.next().await.unwrap();
1104
1105    cx.run_until_parked();
1106
1107    assert_eq!(
1108        &reject_request.rejections,
1109        &[
1110            EditPredictionRejection {
1111                request_id: cancelled_id,
1112                reason: EditPredictionRejectReason::Canceled,
1113                was_shown: false
1114            },
1115            EditPredictionRejection {
1116                request_id: first_id,
1117                reason: EditPredictionRejectReason::Replaced,
1118                was_shown: false
1119            }
1120        ]
1121    );
1122}
1123
1124#[gpui::test]
1125async fn test_rejections_flushing(cx: &mut TestAppContext) {
1126    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1127
1128    ep_store.update(cx, |ep_store, _cx| {
1129        ep_store.reject_prediction(
1130            EditPredictionId("test-1".into()),
1131            EditPredictionRejectReason::Discarded,
1132            false,
1133        );
1134        ep_store.reject_prediction(
1135            EditPredictionId("test-2".into()),
1136            EditPredictionRejectReason::Canceled,
1137            true,
1138        );
1139    });
1140
1141    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1142    cx.run_until_parked();
1143
1144    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1145    respond_tx.send(()).unwrap();
1146
1147    // batched
1148    assert_eq!(reject_request.rejections.len(), 2);
1149    assert_eq!(
1150        reject_request.rejections[0],
1151        EditPredictionRejection {
1152            request_id: "test-1".to_string(),
1153            reason: EditPredictionRejectReason::Discarded,
1154            was_shown: false
1155        }
1156    );
1157    assert_eq!(
1158        reject_request.rejections[1],
1159        EditPredictionRejection {
1160            request_id: "test-2".to_string(),
1161            reason: EditPredictionRejectReason::Canceled,
1162            was_shown: true
1163        }
1164    );
1165
1166    // Reaching batch size limit sends without debounce
1167    ep_store.update(cx, |ep_store, _cx| {
1168        for i in 0..70 {
1169            ep_store.reject_prediction(
1170                EditPredictionId(format!("batch-{}", i).into()),
1171                EditPredictionRejectReason::Discarded,
1172                false,
1173            );
1174        }
1175    });
1176
1177    // First MAX/2 items are sent immediately
1178    cx.run_until_parked();
1179    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1180    respond_tx.send(()).unwrap();
1181
1182    assert_eq!(reject_request.rejections.len(), 50);
1183    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1184    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1185
1186    // Remaining items are debounced with the next batch
1187    cx.executor().advance_clock(Duration::from_secs(15));
1188    cx.run_until_parked();
1189
1190    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1191    respond_tx.send(()).unwrap();
1192
1193    assert_eq!(reject_request.rejections.len(), 20);
1194    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1195    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1196
1197    // Request failure
1198    ep_store.update(cx, |ep_store, _cx| {
1199        ep_store.reject_prediction(
1200            EditPredictionId("retry-1".into()),
1201            EditPredictionRejectReason::Discarded,
1202            false,
1203        );
1204    });
1205
1206    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1207    cx.run_until_parked();
1208
1209    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1210    assert_eq!(reject_request.rejections.len(), 1);
1211    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1212    // Simulate failure
1213    drop(_respond_tx);
1214
1215    // Add another rejection
1216    ep_store.update(cx, |ep_store, _cx| {
1217        ep_store.reject_prediction(
1218            EditPredictionId("retry-2".into()),
1219            EditPredictionRejectReason::Discarded,
1220            false,
1221        );
1222    });
1223
1224    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1225    cx.run_until_parked();
1226
1227    // Retry should include both the failed item and the new one
1228    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1229    respond_tx.send(()).unwrap();
1230
1231    assert_eq!(reject_request.rejections.len(), 2);
1232    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1233    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1234}
1235
1236// Skipped until we start including diagnostics in prompt
1237// #[gpui::test]
1238// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1239//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1240//     let fs = FakeFs::new(cx.executor());
1241//     fs.insert_tree(
1242//         "/root",
1243//         json!({
1244//             "foo.md": "Hello!\nBye"
1245//         }),
1246//     )
1247//     .await;
1248//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1249
1250//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1251//     let diagnostic = lsp::Diagnostic {
1252//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1253//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1254//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1255//         ..Default::default()
1256//     };
1257
1258//     project.update(cx, |project, cx| {
1259//         project.lsp_store().update(cx, |lsp_store, cx| {
1260//             // Create some diagnostics
1261//             lsp_store
1262//                 .update_diagnostics(
1263//                     LanguageServerId(0),
1264//                     lsp::PublishDiagnosticsParams {
1265//                         uri: path_to_buffer_uri.clone(),
1266//                         diagnostics: vec![diagnostic],
1267//                         version: None,
1268//                     },
1269//                     None,
1270//                     language::DiagnosticSourceKind::Pushed,
1271//                     &[],
1272//                     cx,
1273//                 )
1274//                 .unwrap();
1275//         });
1276//     });
1277
1278//     let buffer = project
1279//         .update(cx, |project, cx| {
1280//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1281//             project.open_buffer(path, cx)
1282//         })
1283//         .await
1284//         .unwrap();
1285
1286//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1287//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1288
1289//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1290//         ep_store.request_prediction(&project, &buffer, position, cx)
1291//     });
1292
1293//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1294
1295//     assert_eq!(request.diagnostic_groups.len(), 1);
1296//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1297//         .unwrap();
1298//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1299//     assert_eq!(
1300//         value,
1301//         json!({
1302//             "entries": [{
1303//                 "range": {
1304//                     "start": 8,
1305//                     "end": 10
1306//                 },
1307//                 "diagnostic": {
1308//                     "source": null,
1309//                     "code": null,
1310//                     "code_description": null,
1311//                     "severity": 1,
1312//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1313//                     "markdown": null,
1314//                     "group_id": 0,
1315//                     "is_primary": true,
1316//                     "is_disk_based": false,
1317//                     "is_unnecessary": false,
1318//                     "source_kind": "Pushed",
1319//                     "data": null,
1320//                     "underline": true
1321//                 }
1322//             }],
1323//             "primary_ix": 0
1324//         })
1325//     );
1326// }
1327
1328// Generate a model response that would apply the given diff to the active file.
1329fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1330    let excerpt =
1331        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1332    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1333
1334    PredictEditsV3Response {
1335        request_id: Uuid::new_v4().to_string(),
1336        output: new_excerpt,
1337    }
1338}
1339
1340fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1341    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
1342}
1343
1344struct RequestChannels {
1345    predict: mpsc::UnboundedReceiver<(
1346        PredictEditsV3Request,
1347        oneshot::Sender<PredictEditsV3Response>,
1348    )>,
1349    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
1350}
1351
1352fn init_test_with_fake_client(
1353    cx: &mut TestAppContext,
1354) -> (Entity<EditPredictionStore>, RequestChannels) {
1355    cx.update(move |cx| {
1356        let settings_store = SettingsStore::test(cx);
1357        cx.set_global(settings_store);
1358        zlog::init_test();
1359
1360        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1361        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1362
1363        let http_client = FakeHttpClient::create({
1364            move |req| {
1365                let uri = req.uri().path().to_string();
1366                let mut body = req.into_body();
1367                let predict_req_tx = predict_req_tx.clone();
1368                let reject_req_tx = reject_req_tx.clone();
1369                async move {
1370                    let resp = match uri.as_str() {
1371                        "/client/llm_tokens" => serde_json::to_string(&json!({
1372                            "token": "test"
1373                        }))
1374                        .unwrap(),
1375                        "/predict_edits/v3" => {
1376                            let mut buf = Vec::new();
1377                            body.read_to_end(&mut buf).await.ok();
1378                            let req = serde_json::from_slice(&buf).unwrap();
1379
1380                            let (res_tx, res_rx) = oneshot::channel();
1381                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1382                            serde_json::to_string(&res_rx.await?).unwrap()
1383                        }
1384                        "/predict_edits/reject" => {
1385                            let mut buf = Vec::new();
1386                            body.read_to_end(&mut buf).await.ok();
1387                            let req = serde_json::from_slice(&buf).unwrap();
1388
1389                            let (res_tx, res_rx) = oneshot::channel();
1390                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1391                            serde_json::to_string(&res_rx.await?).unwrap()
1392                        }
1393                        _ => {
1394                            panic!("Unexpected path: {}", uri)
1395                        }
1396                    };
1397
1398                    Ok(Response::builder().body(resp.into()).unwrap())
1399                }
1400            }
1401        });
1402
1403        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1404        client.cloud_client().set_credentials(1, "test".into());
1405
1406        language_model::init(client.clone(), cx);
1407
1408        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1409        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1410
1411        (
1412            ep_store,
1413            RequestChannels {
1414                predict: predict_req_rx,
1415                reject: reject_req_rx,
1416            },
1417        )
1418    })
1419}
1420
/// Contents of the 0BSD license. The data-collection tests below use it as the `LICENSE` file of
/// worktrees whose files they expect to be eligible for data collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1422
1423#[gpui::test]
1424async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1425    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1426    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1427        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1428    });
1429
1430    let edit_preview = cx
1431        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1432        .await;
1433
1434    let prediction = EditPrediction {
1435        edits,
1436        edit_preview,
1437        buffer: buffer.clone(),
1438        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1439        id: EditPredictionId("the-id".into()),
1440        inputs: ZetaPromptInput {
1441            events: Default::default(),
1442            related_files: Default::default(),
1443            cursor_path: Path::new("").into(),
1444            cursor_excerpt: "".into(),
1445            editable_range_in_excerpt: 0..0,
1446            cursor_offset_in_excerpt: 0,
1447        },
1448        buffer_snapshotted_at: Instant::now(),
1449        response_received_at: Instant::now(),
1450    };
1451
1452    cx.update(|cx| {
1453        assert_eq!(
1454            from_completion_edits(
1455                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1456                &buffer,
1457                cx
1458            ),
1459            vec![(2..5, "REM".into()), (9..11, "".into())]
1460        );
1461
1462        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1463        assert_eq!(
1464            from_completion_edits(
1465                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1466                &buffer,
1467                cx
1468            ),
1469            vec![(2..2, "REM".into()), (6..8, "".into())]
1470        );
1471
1472        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1473        assert_eq!(
1474            from_completion_edits(
1475                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1476                &buffer,
1477                cx
1478            ),
1479            vec![(2..5, "REM".into()), (9..11, "".into())]
1480        );
1481
1482        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1483        assert_eq!(
1484            from_completion_edits(
1485                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1486                &buffer,
1487                cx
1488            ),
1489            vec![(3..3, "EM".into()), (7..9, "".into())]
1490        );
1491
1492        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1493        assert_eq!(
1494            from_completion_edits(
1495                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1496                &buffer,
1497                cx
1498            ),
1499            vec![(4..4, "M".into()), (8..10, "".into())]
1500        );
1501
1502        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1503        assert_eq!(
1504            from_completion_edits(
1505                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1506                &buffer,
1507                cx
1508            ),
1509            vec![(9..11, "".into())]
1510        );
1511
1512        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1513        assert_eq!(
1514            from_completion_edits(
1515                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1516                &buffer,
1517                cx
1518            ),
1519            vec![(4..4, "M".into()), (8..10, "".into())]
1520        );
1521
1522        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1523        assert_eq!(
1524            from_completion_edits(
1525                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1526                &buffer,
1527                cx
1528            ),
1529            vec![(4..4, "M".into())]
1530        );
1531
1532        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1533        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1534    })
1535}
1536
1537#[gpui::test]
1538async fn test_clean_up_diff(cx: &mut TestAppContext) {
1539    init_test(cx);
1540
1541    assert_eq!(
1542        apply_edit_prediction(
1543            indoc! {"
1544                    fn main() {
1545                        let word_1 = \"lorem\";
1546                        let range = word.len()..word.len();
1547                    }
1548                "},
1549            indoc! {"
1550                    <|editable_region_start|>
1551                    fn main() {
1552                        let word_1 = \"lorem\";
1553                        let range = word_1.len()..word_1.len();
1554                    }
1555
1556                    <|editable_region_end|>
1557                "},
1558            cx,
1559        )
1560        .await,
1561        indoc! {"
1562                fn main() {
1563                    let word_1 = \"lorem\";
1564                    let range = word_1.len()..word_1.len();
1565                }
1566            "},
1567    );
1568
1569    assert_eq!(
1570        apply_edit_prediction(
1571            indoc! {"
1572                    fn main() {
1573                        let story = \"the quick\"
1574                    }
1575                "},
1576            indoc! {"
1577                    <|editable_region_start|>
1578                    fn main() {
1579                        let story = \"the quick brown fox jumps over the lazy dog\";
1580                    }
1581
1582                    <|editable_region_end|>
1583                "},
1584            cx,
1585        )
1586        .await,
1587        indoc! {"
1588                fn main() {
1589                    let story = \"the quick brown fox jumps over the lazy dog\";
1590                }
1591            "},
1592    );
1593}
1594
1595#[gpui::test]
1596async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1597    init_test(cx);
1598
1599    let buffer_content = "lorem\n";
1600    let completion_response = indoc! {"
1601            ```animals.js
1602            <|start_of_file|>
1603            <|editable_region_start|>
1604            lorem
1605            ipsum
1606            <|editable_region_end|>
1607            ```"};
1608
1609    assert_eq!(
1610        apply_edit_prediction(buffer_content, completion_response, cx).await,
1611        "lorem\nipsum"
1612    );
1613}
1614
1615#[gpui::test]
1616async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1617    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1618    // When the buffer ends without a trailing newline, but the model returns output
1619    // with a trailing newline, zeta2 should normalize both sides before diffing
1620    // so no spurious newline is inserted.
1621    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1622    let fs = FakeFs::new(cx.executor());
1623
1624    // Single line buffer with no trailing newline
1625    fs.insert_tree(
1626        "/root",
1627        json!({
1628            "foo.txt": "hello"
1629        }),
1630    )
1631    .await;
1632    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1633
1634    let buffer = project
1635        .update(cx, |project, cx| {
1636            let path = project
1637                .find_project_path(path!("root/foo.txt"), cx)
1638                .unwrap();
1639            project.open_buffer(path, cx)
1640        })
1641        .await
1642        .unwrap();
1643
1644    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1645    let position = snapshot.anchor_before(language::Point::new(0, 5));
1646
1647    ep_store.update(cx, |ep_store, cx| {
1648        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1649    });
1650
1651    let (_request, respond_tx) = requests.predict.next().await.unwrap();
1652
1653    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1654    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1655    let response = PredictEditsV3Response {
1656        request_id: Uuid::new_v4().to_string(),
1657        output: "hello world\n".to_string(),
1658    };
1659    respond_tx.send(response).unwrap();
1660
1661    cx.run_until_parked();
1662
1663    // The prediction should insert " world" without adding a newline
1664    ep_store.update(cx, |ep_store, cx| {
1665        let prediction = ep_store
1666            .prediction_at(&buffer, None, &project, cx)
1667            .expect("should have prediction");
1668        let edits: Vec<_> = prediction
1669            .edits
1670            .iter()
1671            .map(|(range, text)| {
1672                let snapshot = buffer.read(cx).snapshot();
1673                (range.to_offset(&snapshot), text.clone())
1674            })
1675            .collect();
1676        assert_eq!(edits, vec![(5..5, " world".into())]);
1677    });
1678}
1679
1680#[gpui::test]
1681async fn test_can_collect_data(cx: &mut TestAppContext) {
1682    init_test(cx);
1683
1684    let fs = project::FakeFs::new(cx.executor());
1685    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1686        .await;
1687
1688    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1689    let buffer = project
1690        .update(cx, |project, cx| {
1691            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1692        })
1693        .await
1694        .unwrap();
1695
1696    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1697    ep_store.update(cx, |ep_store, _cx| {
1698        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1699    });
1700
1701    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1702    assert_eq!(
1703        captured_request.lock().clone().unwrap().can_collect_data,
1704        true
1705    );
1706
1707    ep_store.update(cx, |ep_store, _cx| {
1708        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1709    });
1710
1711    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1712    assert_eq!(
1713        captured_request.lock().clone().unwrap().can_collect_data,
1714        false
1715    );
1716}
1717
1718#[gpui::test]
1719async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1720    init_test(cx);
1721
1722    let fs = project::FakeFs::new(cx.executor());
1723    let project = Project::test(fs.clone(), [], cx).await;
1724
1725    let buffer = cx.new(|_cx| {
1726        Buffer::remote(
1727            language::BufferId::new(1).unwrap(),
1728            ReplicaId::new(1),
1729            language::Capability::ReadWrite,
1730            "fn main() {\n    println!(\"Hello\");\n}",
1731        )
1732    });
1733
1734    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1735    ep_store.update(cx, |ep_store, _cx| {
1736        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1737    });
1738
1739    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1740    assert_eq!(
1741        captured_request.lock().clone().unwrap().can_collect_data,
1742        false
1743    );
1744}
1745
1746#[gpui::test]
1747async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1748    init_test(cx);
1749
1750    let fs = project::FakeFs::new(cx.executor());
1751    fs.insert_tree(
1752        path!("/project"),
1753        json!({
1754            "LICENSE": BSD_0_TXT,
1755            ".env": "SECRET_KEY=secret"
1756        }),
1757    )
1758    .await;
1759
1760    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1761    let buffer = project
1762        .update(cx, |project, cx| {
1763            project.open_local_buffer("/project/.env", cx)
1764        })
1765        .await
1766        .unwrap();
1767
1768    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1769    ep_store.update(cx, |ep_store, _cx| {
1770        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1771    });
1772
1773    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1774    assert_eq!(
1775        captured_request.lock().clone().unwrap().can_collect_data,
1776        false
1777    );
1778}
1779
1780#[gpui::test]
1781async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1782    init_test(cx);
1783
1784    let fs = project::FakeFs::new(cx.executor());
1785    let project = Project::test(fs.clone(), [], cx).await;
1786    let buffer = cx.new(|cx| Buffer::local("", cx));
1787
1788    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1789    ep_store.update(cx, |ep_store, _cx| {
1790        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1791    });
1792
1793    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1794    assert_eq!(
1795        captured_request.lock().clone().unwrap().can_collect_data,
1796        false
1797    );
1798}
1799
1800#[gpui::test]
1801async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1802    init_test(cx);
1803
1804    let fs = project::FakeFs::new(cx.executor());
1805    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1806        .await;
1807
1808    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1809    let buffer = project
1810        .update(cx, |project, cx| {
1811            project.open_local_buffer("/project/main.rs", cx)
1812        })
1813        .await
1814        .unwrap();
1815
1816    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1817    ep_store.update(cx, |ep_store, _cx| {
1818        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1819    });
1820
1821    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1822    assert_eq!(
1823        captured_request.lock().clone().unwrap().can_collect_data,
1824        false
1825    );
1826}
1827
1828#[gpui::test]
1829async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1830    init_test(cx);
1831
1832    let fs = project::FakeFs::new(cx.executor());
1833    fs.insert_tree(
1834        path!("/open_source_worktree"),
1835        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1836    )
1837    .await;
1838    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1839        .await;
1840
1841    let project = Project::test(
1842        fs.clone(),
1843        [
1844            path!("/open_source_worktree").as_ref(),
1845            path!("/closed_source_worktree").as_ref(),
1846        ],
1847        cx,
1848    )
1849    .await;
1850    let buffer = project
1851        .update(cx, |project, cx| {
1852            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1853        })
1854        .await
1855        .unwrap();
1856
1857    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1858    ep_store.update(cx, |ep_store, _cx| {
1859        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1860    });
1861
1862    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1863    assert_eq!(
1864        captured_request.lock().clone().unwrap().can_collect_data,
1865        true
1866    );
1867
1868    let closed_source_file = project
1869        .update(cx, |project, cx| {
1870            let worktree2 = project
1871                .worktree_for_root_name("closed_source_worktree", cx)
1872                .unwrap();
1873            worktree2.update(cx, |worktree2, cx| {
1874                worktree2.load_file(rel_path("main.rs"), cx)
1875            })
1876        })
1877        .await
1878        .unwrap()
1879        .file;
1880
1881    buffer.update(cx, |buffer, cx| {
1882        buffer.file_updated(closed_source_file, cx);
1883    });
1884
1885    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1886    assert_eq!(
1887        captured_request.lock().clone().unwrap().can_collect_data,
1888        false
1889    );
1890}
1891
1892#[gpui::test]
1893async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1894    init_test(cx);
1895
1896    let fs = project::FakeFs::new(cx.executor());
1897    fs.insert_tree(
1898        path!("/worktree1"),
1899        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1900    )
1901    .await;
1902    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1903        .await;
1904
1905    let project = Project::test(
1906        fs.clone(),
1907        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1908        cx,
1909    )
1910    .await;
1911    let buffer = project
1912        .update(cx, |project, cx| {
1913            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1914        })
1915        .await
1916        .unwrap();
1917    let private_buffer = project
1918        .update(cx, |project, cx| {
1919            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1920        })
1921        .await
1922        .unwrap();
1923
1924    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1925    ep_store.update(cx, |ep_store, _cx| {
1926        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1927    });
1928
1929    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1930    assert_eq!(
1931        captured_request.lock().clone().unwrap().can_collect_data,
1932        true
1933    );
1934
1935    // this has a side effect of registering the buffer to watch for edits
1936    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1937    assert_eq!(
1938        captured_request.lock().clone().unwrap().can_collect_data,
1939        false
1940    );
1941
1942    private_buffer.update(cx, |private_buffer, cx| {
1943        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1944    });
1945
1946    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1947    assert_eq!(
1948        captured_request.lock().clone().unwrap().can_collect_data,
1949        false
1950    );
1951
1952    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1953    // included
1954    buffer.update(cx, |buffer, cx| {
1955        buffer.edit(
1956            [(
1957                0..0,
1958                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1959            )],
1960            None,
1961            cx,
1962        );
1963    });
1964
1965    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1966    assert_eq!(
1967        captured_request.lock().clone().unwrap().can_collect_data,
1968        true
1969    );
1970}
1971
1972fn init_test(cx: &mut TestAppContext) {
1973    cx.update(|cx| {
1974        let settings_store = SettingsStore::test(cx);
1975        cx.set_global(settings_store);
1976    });
1977}
1978
1979async fn apply_edit_prediction(
1980    buffer_content: &str,
1981    completion_response: &str,
1982    cx: &mut TestAppContext,
1983) -> String {
1984    let fs = project::FakeFs::new(cx.executor());
1985    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1986    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1987    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1988    *response.lock() = completion_response.to_string();
1989    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1990    buffer.update(cx, |buffer, cx| {
1991        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1992    });
1993    buffer.read_with(cx, |buffer, _| buffer.text())
1994}
1995
1996async fn run_edit_prediction(
1997    buffer: &Entity<Buffer>,
1998    project: &Entity<Project>,
1999    ep_store: &Entity<EditPredictionStore>,
2000    cx: &mut TestAppContext,
2001) -> EditPrediction {
2002    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2003    ep_store.update(cx, |ep_store, cx| {
2004        ep_store.register_buffer(buffer, &project, cx)
2005    });
2006    cx.background_executor.run_until_parked();
2007    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2008        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2009    });
2010    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2011}
2012
2013async fn make_test_ep_store(
2014    project: &Entity<Project>,
2015    cx: &mut TestAppContext,
2016) -> (
2017    Entity<EditPredictionStore>,
2018    Arc<Mutex<Option<PredictEditsBody>>>,
2019    Arc<Mutex<String>>,
2020) {
2021    let default_response = indoc! {"
2022            ```main.rs
2023            <|start_of_file|>
2024            <|editable_region_start|>
2025            hello world
2026            <|editable_region_end|>
2027            ```"
2028    };
2029    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2030    let completion_response: Arc<Mutex<String>> =
2031        Arc::new(Mutex::new(default_response.to_string()));
2032    let http_client = FakeHttpClient::create({
2033        let captured_request = captured_request.clone();
2034        let completion_response = completion_response.clone();
2035        let mut next_request_id = 0;
2036        move |req| {
2037            let captured_request = captured_request.clone();
2038            let completion_response = completion_response.clone();
2039            async move {
2040                match (req.method(), req.uri().path()) {
2041                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2042                        .status(200)
2043                        .body(
2044                            serde_json::to_string(&CreateLlmTokenResponse {
2045                                token: LlmToken("the-llm-token".to_string()),
2046                            })
2047                            .unwrap()
2048                            .into(),
2049                        )
2050                        .unwrap()),
2051                    (&Method::POST, "/predict_edits/v2") => {
2052                        let mut request_body = String::new();
2053                        req.into_body().read_to_string(&mut request_body).await?;
2054                        *captured_request.lock() =
2055                            Some(serde_json::from_str(&request_body).unwrap());
2056                        next_request_id += 1;
2057                        Ok(http_client::Response::builder()
2058                            .status(200)
2059                            .body(
2060                                serde_json::to_string(&PredictEditsResponse {
2061                                    request_id: format!("request-{next_request_id}"),
2062                                    output_excerpt: completion_response.lock().clone(),
2063                                })
2064                                .unwrap()
2065                                .into(),
2066                            )
2067                            .unwrap())
2068                    }
2069                    _ => Ok(http_client::Response::builder()
2070                        .status(404)
2071                        .body("Not Found".into())
2072                        .unwrap()),
2073                }
2074            }
2075        }
2076    });
2077
2078    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2079    cx.update(|cx| {
2080        RefreshLlmTokenListener::register(client.clone(), cx);
2081    });
2082    let _server = FakeServer::for_client(42, &client, cx).await;
2083
2084    let ep_store = cx.new(|cx| {
2085        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2086        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2087
2088        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2089        for worktree in worktrees {
2090            let worktree_id = worktree.read(cx).id();
2091            ep_store
2092                .get_or_init_project(project, cx)
2093                .license_detection_watchers
2094                .entry(worktree_id)
2095                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2096        }
2097
2098        ep_store
2099    });
2100
2101    (ep_store, captured_request, completion_response)
2102}
2103
2104fn to_completion_edits(
2105    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2106    buffer: &Entity<Buffer>,
2107    cx: &App,
2108) -> Vec<(Range<Anchor>, Arc<str>)> {
2109    let buffer = buffer.read(cx);
2110    iterator
2111        .into_iter()
2112        .map(|(range, text)| {
2113            (
2114                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2115                text,
2116            )
2117        })
2118        .collect()
2119}
2120
2121fn from_completion_edits(
2122    editor_edits: &[(Range<Anchor>, Arc<str>)],
2123    buffer: &Entity<Buffer>,
2124    cx: &App,
2125) -> Vec<(Range<usize>, Arc<str>)> {
2126    let buffer = buffer.read(cx);
2127    editor_edits
2128        .iter()
2129        .map(|(range, text)| {
2130            (
2131                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2132                text.clone(),
2133            )
2134        })
2135        .collect()
2136}
2137
2138#[gpui::test]
2139async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2140    init_test(cx);
2141
2142    let fs = FakeFs::new(cx.executor());
2143    fs.insert_tree(
2144        "/project",
2145        serde_json::json!({
2146            "main.rs": "fn main() {\n    \n}\n"
2147        }),
2148    )
2149    .await;
2150
2151    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2152
2153    let http_client = FakeHttpClient::create(|_req| async move {
2154        Ok(gpui::http_client::Response::builder()
2155            .status(401)
2156            .body("Unauthorized".into())
2157            .unwrap())
2158    });
2159
2160    let client =
2161        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2162    cx.update(|cx| {
2163        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2164    });
2165
2166    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2167
2168    let buffer = project
2169        .update(cx, |project, cx| {
2170            let path = project
2171                .find_project_path(path!("/project/main.rs"), cx)
2172                .unwrap();
2173            project.open_buffer(path, cx)
2174        })
2175        .await
2176        .unwrap();
2177
2178    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2179    ep_store.update(cx, |ep_store, cx| {
2180        ep_store.register_buffer(&buffer, &project, cx)
2181    });
2182    cx.background_executor.run_until_parked();
2183
2184    let completion_task = ep_store.update(cx, |ep_store, cx| {
2185        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2186        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2187    });
2188
2189    let result = completion_task.await;
2190    assert!(
2191        result.is_err(),
2192        "Without authentication and without custom URL, prediction should fail"
2193    );
2194}
2195
2196#[gpui::test]
2197async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2198    init_test(cx);
2199
2200    let fs = FakeFs::new(cx.executor());
2201    fs.insert_tree(
2202        "/project",
2203        serde_json::json!({
2204            "main.rs": "fn main() {\n    \n}\n"
2205        }),
2206    )
2207    .await;
2208
2209    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2210
2211    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2212    let predict_called_clone = predict_called.clone();
2213
2214    let http_client = FakeHttpClient::create({
2215        move |req| {
2216            let uri = req.uri().path().to_string();
2217            let predict_called = predict_called_clone.clone();
2218            async move {
2219                if uri.contains("predict") {
2220                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2221                    Ok(gpui::http_client::Response::builder()
2222                        .body(
2223                            serde_json::to_string(&open_ai::Response {
2224                                id: "test-123".to_string(),
2225                                object: "chat.completion".to_string(),
2226                                created: 0,
2227                                model: "test".to_string(),
2228                                usage: open_ai::Usage {
2229                                    prompt_tokens: 0,
2230                                    completion_tokens: 0,
2231                                    total_tokens: 0,
2232                                },
2233                                choices: vec![open_ai::Choice {
2234                                    index: 0,
2235                                    message: open_ai::RequestMessage::Assistant {
2236                                        content: Some(open_ai::MessageContent::Plain(
2237                                            indoc! {"
2238                                                ```main.rs
2239                                                <|start_of_file|>
2240                                                <|editable_region_start|>
2241                                                fn main() {
2242                                                    println!(\"Hello, world!\");
2243                                                }
2244                                                <|editable_region_end|>
2245                                                ```
2246                                            "}
2247                                            .to_string(),
2248                                        )),
2249                                        tool_calls: vec![],
2250                                    },
2251                                    finish_reason: Some("stop".to_string()),
2252                                }],
2253                            })
2254                            .unwrap()
2255                            .into(),
2256                        )
2257                        .unwrap())
2258                } else {
2259                    Ok(gpui::http_client::Response::builder()
2260                        .status(401)
2261                        .body("Unauthorized".into())
2262                        .unwrap())
2263                }
2264            }
2265        }
2266    });
2267
2268    let client =
2269        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2270    cx.update(|cx| {
2271        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2272    });
2273
2274    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2275
2276    let buffer = project
2277        .update(cx, |project, cx| {
2278            let path = project
2279                .find_project_path(path!("/project/main.rs"), cx)
2280                .unwrap();
2281            project.open_buffer(path, cx)
2282        })
2283        .await
2284        .unwrap();
2285
2286    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2287    ep_store.update(cx, |ep_store, cx| {
2288        ep_store.register_buffer(&buffer, &project, cx)
2289    });
2290    cx.background_executor.run_until_parked();
2291
2292    let completion_task = ep_store.update(cx, |ep_store, cx| {
2293        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2294        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2295        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2296    });
2297
2298    let _ = completion_task.await;
2299
2300    assert!(
2301        predict_called.load(std::sync::atomic::Ordering::SeqCst),
2302        "With custom URL, predict endpoint should be called even without authentication"
2303    );
2304}
2305
2306#[gpui::test]
2307fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2308    let buffer = cx.new(|cx| {
2309        Buffer::local(
2310            indoc! {"
2311                zero
2312                one
2313                two
2314                three
2315                four
2316                five
2317                six
2318                seven
2319                eight
2320                nine
2321                ten
2322                eleven
2323                twelve
2324                thirteen
2325                fourteen
2326                fifteen
2327                sixteen
2328                seventeen
2329                eighteen
2330                nineteen
2331                twenty
2332                twenty-one
2333                twenty-two
2334                twenty-three
2335                twenty-four
2336            "},
2337            cx,
2338        )
2339    });
2340
2341    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2342
2343    buffer.update(cx, |buffer, cx| {
2344        let point = Point::new(12, 0);
2345        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2346        let point = Point::new(8, 0);
2347        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2348    });
2349
2350    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2351
2352    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2353
2354    assert_eq!(
2355        diff,
2356        indoc! {"
2357            @@ -6,10 +6,12 @@
2358             five
2359             six
2360             seven
2361            +FIRST INSERTION
2362             eight
2363             nine
2364             ten
2365             eleven
2366            +SECOND INSERTION
2367             twelve
2368             thirteen
2369             fourteen
2370            "}
2371    );
2372}
2373
/// Calls `zlog::init_test()` once, when the test binary is loaded.
/// `#[ctor::ctor]` runs it before `main`, so logging is configured before
/// any test in this module executes.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}