//! Tests for the edit prediction store (edit_prediction_tests.rs).

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  10};
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::App;
  16use gpui::{
  17    Entity, TestAppContext,
  18    http_client::{FakeHttpClient, Response},
  19};
  20use indoc::indoc;
  21use language::{Buffer, Point};
  22use lsp::LanguageServerId;
  23use parking_lot::Mutex;
  24use pretty_assertions::{assert_eq, assert_matches};
  25use project::{FakeFs, Project};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, sync::Arc, time::Duration};
  29use util::{path, rel_path::rel_path};
  30use uuid::Uuid;
  31use zeta_prompt::ZetaPromptInput;
  32
  33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  34
// Exercises how `prediction_at` reports predictions relative to the queried
// buffer: a prediction targeting the same buffer is `Local`, while a
// prediction targeting another file is surfaced as a `Jump` until that
// file's own buffer is queried.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor placed at the end of "How" (row 1, column 3).
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Respond with a diff that edits the buffer the cursor is in.
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The prediction targets the queried buffer, so it is reported as Local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction so the next one starts from a clean slate.
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt is followed by a new predict request
    // below without any explicit refresh call — presumably the store reacts
    // to diagnostic updates itself (TODO confirm in store implementation).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from buffer1, a prediction for 2.txt is reported as a Jump to
    // that file.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Queried from the prediction's own buffer, the same prediction is Local.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 170
// A single end-to-end request/response round trip: the model's unified diff
// is converted into exactly one buffer edit, anchored at the cursor.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "How" (row 1, column 3).
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The "How" -> "How are you?" replacement collapses to a single insertion
    // of " are you?" starting at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 237
// Verifies that an edit made to a registered buffer before requesting a
// prediction is recorded as an event and included as a diff in the prompt
// sent to the model.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Edit the registered buffer so the store records a change event.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The prompt must contain the edit-history diff for the change above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 310
// Edits separated by a pause longer than the grouping threshold still
// coalesce into a single stored event (same line span), but the
// pause-splitting getter re-splits that last event into two separate diffs
// at the recorded pause boundary.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 405
// Edits landing within 8 lines of an existing event's range — whether above
// its start or below its end — coalesce into that event; edits farther away
// start a new, separate event.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 584
 585fn render_events(events: &[StoredEvent]) -> String {
 586    events
 587        .iter()
 588        .map(|e| {
 589            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 590            diff.as_str()
 591        })
 592        .collect::<Vec<_>>()
 593        .join("\n---\n")
 594}
 595
// An empty model response produces no prediction, and the request is
// automatically reported back as rejected with reason `Empty`.
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff; remember the request id for the rejection
    // assertion below.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
 650
// If the user types the predicted text while the request is in flight, the
// prediction interpolates down to no remaining edits and is auto-rejected
// with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the predicted change manually before the response arrives, so the
    // returned edits have nothing left to contribute.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
 710
/// Canned model response shared by several tests below: a unified diff
/// against `root/foo.md` rewriting the line `How` into `How are you?`.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 720
// A newly arriving prediction replaces the current one, and the replaced
// prediction is reported as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
 802
// When a new prediction is judged worse than the current one (here, a strict
// prefix of the current edit), the current prediction is kept and the new
// one is rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
 896
// When a later request completes first, earlier in-flight requests are
// canceled: their late responses are ignored and they are reported back as
// rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // The first request's response arrives late, after the second already won.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
 987
// Verifies the cancellation policy when a third prediction request starts
// while two are still in flight: the second (newest pending) request is
// cancelled, a cancelled response never becomes the current prediction, and
// the cancelled + replaced predictions are reported via the reject endpoint.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Resolve the first request; it should become the current prediction.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Resolve the (already cancelled) second request; its response must be
    // ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // One batched reject request covers both the cancelled second prediction
    // and the first prediction that was replaced by the third.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1123
// Exercises how rejection reports are flushed to the server:
// - rejections are debounced and sent as a single batch,
// - hitting the batch-size limit flushes immediately without the debounce,
// - a failed flush is retried together with subsequently queued rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    // Nothing is sent until the debounce window elapses.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched: both rejections arrive in one request, in submission order
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    // NOTE(review): 15s looks like it's meant to exceed REJECT_REQUEST_DEBOUNCE;
    // consider expressing it in terms of the constant — confirm intent.
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure: dropping the responder makes the fake server's
    // `res_rx.await?` fail, so the HTTP request errors out.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1235
1236// Skipped until we start including diagnostics in prompt
1237// #[gpui::test]
1238// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1239//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1240//     let fs = FakeFs::new(cx.executor());
1241//     fs.insert_tree(
1242//         "/root",
1243//         json!({
1244//             "foo.md": "Hello!\nBye"
1245//         }),
1246//     )
1247//     .await;
1248//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1249
1250//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1251//     let diagnostic = lsp::Diagnostic {
1252//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1253//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1254//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1255//         ..Default::default()
1256//     };
1257
1258//     project.update(cx, |project, cx| {
1259//         project.lsp_store().update(cx, |lsp_store, cx| {
1260//             // Create some diagnostics
1261//             lsp_store
1262//                 .update_diagnostics(
1263//                     LanguageServerId(0),
1264//                     lsp::PublishDiagnosticsParams {
1265//                         uri: path_to_buffer_uri.clone(),
1266//                         diagnostics: vec![diagnostic],
1267//                         version: None,
1268//                     },
1269//                     None,
1270//                     language::DiagnosticSourceKind::Pushed,
1271//                     &[],
1272//                     cx,
1273//                 )
1274//                 .unwrap();
1275//         });
1276//     });
1277
1278//     let buffer = project
1279//         .update(cx, |project, cx| {
1280//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1281//             project.open_buffer(path, cx)
1282//         })
1283//         .await
1284//         .unwrap();
1285
1286//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1287//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1288
1289//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1290//         ep_store.request_prediction(&project, &buffer, position, cx)
1291//     });
1292
1293//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1294
1295//     assert_eq!(request.diagnostic_groups.len(), 1);
1296//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1297//         .unwrap();
1298//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1299//     assert_eq!(
1300//         value,
1301//         json!({
1302//             "entries": [{
1303//                 "range": {
1304//                     "start": 8,
1305//                     "end": 10
1306//                 },
1307//                 "diagnostic": {
1308//                     "source": null,
1309//                     "code": null,
1310//                     "code_description": null,
1311//                     "severity": 1,
1312//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1313//                     "markdown": null,
1314//                     "group_id": 0,
1315//                     "is_primary": true,
1316//                     "is_disk_based": false,
1317//                     "is_unnecessary": false,
1318//                     "source_kind": "Pushed",
1319//                     "data": null,
1320//                     "underline": true
1321//                 }
1322//             }],
1323//             "primary_ix": 0
1324//         })
1325//     );
1326// }
1327
1328// Generate a model response that would apply the given diff to the active file.
1329fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1330    let excerpt =
1331        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1332    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1333
1334    PredictEditsV3Response {
1335        request_id: Uuid::new_v4().to_string(),
1336        output: new_excerpt,
1337    }
1338}
1339
/// Render the prompt string for `request`, using the prompt version the
/// request was made with.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
}
1343
/// Receiving ends of the channels that the fake HTTP client created by
/// `init_test_with_fake_client` forwards intercepted cloud requests to.
/// Each item pairs a deserialized request body with the oneshot sender the
/// test must use to supply the response (dropping it simulates a failure).
struct RequestChannels {
    /// "/predict_edits/v3" requests.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// "/predict_edits/reject" requests.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1351
/// Set up an `EditPredictionStore` backed by a fake HTTP client.
///
/// The fake client answers "/client/llm_tokens" inline with a dummy token,
/// and forwards "/predict_edits/v3" and "/predict_edits/reject" request
/// bodies onto the returned `RequestChannels`; the response is whatever the
/// test sends back on the paired oneshot sender. Any other path panics.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh is answered inline, no test involvement.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Hand the request to the test and wait for it to
                            // respond; if the test drops the sender, `?`
                            // propagates an error and the request fails.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1420
/// Permissive (0BSD) license text; the data-collection tests below place it
/// in a worktree's LICENSE file to mark the project as open source.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1422
// Verifies `EditPrediction::interpolate`: as the user edits the buffer, the
// prediction's remaining edits are re-anchored/shrunk to reflect what the
// user has already typed, and interpolation fails (returns `None`) once the
// user's edits diverge from the prediction.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // Prediction: replace offsets 2..5 with "REM" and delete offsets 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Minimal prediction; `inputs` is irrelevant to interpolation and left empty.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the replacement into a pure
        // insertion; the second edit shifts left by the deleted length.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the replacement ("R") shrinks the remaining
        // insertion to the suffix ("EM").
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the whole replacement consumes the first edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of what was typed re-introduces the remaining suffix.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion consumes the second edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // An edit that conflicts with the prediction invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1538
// Checks that applying a model completion cleans up the raw model output:
// the <|editable_region_*|> markers are stripped and the extra trailing
// blank line present in the model's response does not end up in the buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Rename fix: model completes `word` -> `word_1`.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
    );

    // String continuation: model extends an unfinished literal and adds the
    // missing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
    );
}
1596
// A completion at the end of the buffer may extend past the final line:
// here the buffer ends with a trailing newline, and the completion appends
// "ipsum" without one — the result is exactly the completion's editable
// region, with the <|...|> markers and code fence stripped.
#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}
1616
1617#[gpui::test]
1618async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1619    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1620    // When the buffer ends without a trailing newline, but the model returns output
1621    // with a trailing newline, zeta2 should normalize both sides before diffing
1622    // so no spurious newline is inserted.
1623    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1624    let fs = FakeFs::new(cx.executor());
1625
1626    // Single line buffer with no trailing newline
1627    fs.insert_tree(
1628        "/root",
1629        json!({
1630            "foo.txt": "hello"
1631        }),
1632    )
1633    .await;
1634    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1635
1636    let buffer = project
1637        .update(cx, |project, cx| {
1638            let path = project
1639                .find_project_path(path!("root/foo.txt"), cx)
1640                .unwrap();
1641            project.open_buffer(path, cx)
1642        })
1643        .await
1644        .unwrap();
1645
1646    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1647    let position = snapshot.anchor_before(language::Point::new(0, 5));
1648
1649    ep_store.update(cx, |ep_store, cx| {
1650        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1651    });
1652
1653    let (_request, respond_tx) = requests.predict.next().await.unwrap();
1654
1655    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1656    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1657    let response = PredictEditsV3Response {
1658        request_id: Uuid::new_v4().to_string(),
1659        output: "hello world\n".to_string(),
1660    };
1661    respond_tx.send(response).unwrap();
1662
1663    cx.run_until_parked();
1664
1665    // The prediction should insert " world" without adding a newline
1666    ep_store.update(cx, |ep_store, cx| {
1667        let prediction = ep_store
1668            .prediction_at(&buffer, None, &project, cx)
1669            .expect("should have prediction");
1670        let edits: Vec<_> = prediction
1671            .edits
1672            .iter()
1673            .map(|(range, text)| {
1674                let snapshot = buffer.read(cx).snapshot();
1675                (range.to_offset(&snapshot), text.clone())
1676            })
1677            .collect();
1678        assert_eq!(edits, vec![(5..5, " world".into())]);
1679    });
1680}
1681
1682#[gpui::test]
1683async fn test_can_collect_data(cx: &mut TestAppContext) {
1684    init_test(cx);
1685
1686    let fs = project::FakeFs::new(cx.executor());
1687    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1688        .await;
1689
1690    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1691    let buffer = project
1692        .update(cx, |project, cx| {
1693            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1694        })
1695        .await
1696        .unwrap();
1697
1698    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1699    ep_store.update(cx, |ep_store, _cx| {
1700        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1701    });
1702
1703    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1704    assert_eq!(
1705        captured_request.lock().clone().unwrap().can_collect_data,
1706        true
1707    );
1708
1709    ep_store.update(cx, |ep_store, _cx| {
1710        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1711    });
1712
1713    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1714    assert_eq!(
1715        captured_request.lock().clone().unwrap().can_collect_data,
1716        false
1717    );
1718}
1719
1720#[gpui::test]
1721async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1722    init_test(cx);
1723
1724    let fs = project::FakeFs::new(cx.executor());
1725    let project = Project::test(fs.clone(), [], cx).await;
1726
1727    let buffer = cx.new(|_cx| {
1728        Buffer::remote(
1729            language::BufferId::new(1).unwrap(),
1730            ReplicaId::new(1),
1731            language::Capability::ReadWrite,
1732            "fn main() {\n    println!(\"Hello\");\n}",
1733        )
1734    });
1735
1736    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1737    ep_store.update(cx, |ep_store, _cx| {
1738        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1739    });
1740
1741    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1742    assert_eq!(
1743        captured_request.lock().clone().unwrap().can_collect_data,
1744        false
1745    );
1746}
1747
1748#[gpui::test]
1749async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1750    init_test(cx);
1751
1752    let fs = project::FakeFs::new(cx.executor());
1753    fs.insert_tree(
1754        path!("/project"),
1755        json!({
1756            "LICENSE": BSD_0_TXT,
1757            ".env": "SECRET_KEY=secret"
1758        }),
1759    )
1760    .await;
1761
1762    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1763    let buffer = project
1764        .update(cx, |project, cx| {
1765            project.open_local_buffer("/project/.env", cx)
1766        })
1767        .await
1768        .unwrap();
1769
1770    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1771    ep_store.update(cx, |ep_store, _cx| {
1772        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1773    });
1774
1775    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1776    assert_eq!(
1777        captured_request.lock().clone().unwrap().can_collect_data,
1778        false
1779    );
1780}
1781
1782#[gpui::test]
1783async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1784    init_test(cx);
1785
1786    let fs = project::FakeFs::new(cx.executor());
1787    let project = Project::test(fs.clone(), [], cx).await;
1788    let buffer = cx.new(|cx| Buffer::local("", cx));
1789
1790    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1791    ep_store.update(cx, |ep_store, _cx| {
1792        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1793    });
1794
1795    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1796    assert_eq!(
1797        captured_request.lock().clone().unwrap().can_collect_data,
1798        false
1799    );
1800}
1801
1802#[gpui::test]
1803async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1804    init_test(cx);
1805
1806    let fs = project::FakeFs::new(cx.executor());
1807    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1808        .await;
1809
1810    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1811    let buffer = project
1812        .update(cx, |project, cx| {
1813            project.open_local_buffer("/project/main.rs", cx)
1814        })
1815        .await
1816        .unwrap();
1817
1818    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1819    ep_store.update(cx, |ep_store, _cx| {
1820        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1821    });
1822
1823    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1824    assert_eq!(
1825        captured_request.lock().clone().unwrap().can_collect_data,
1826        false
1827    );
1828}
1829
// Verifies that a buffer's data-collection eligibility is re-evaluated when
// its backing file moves: collectable while it lives in an open-source
// worktree, not collectable after it is re-pointed at a closed-source one.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    // One worktree with a BSD-0 LICENSE (open source), one without.
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While the buffer belongs to the open-source worktree, collection is on.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Load a file handle from the closed-source worktree to simulate the
    // buffer's file being moved there.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    // Re-point the buffer at the closed-source file.
    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    // After the move, the next prediction request must not allow collection.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1893
1894#[gpui::test]
1895async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1896    init_test(cx);
1897
1898    let fs = project::FakeFs::new(cx.executor());
1899    fs.insert_tree(
1900        path!("/worktree1"),
1901        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1902    )
1903    .await;
1904    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1905        .await;
1906
1907    let project = Project::test(
1908        fs.clone(),
1909        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1910        cx,
1911    )
1912    .await;
1913    let buffer = project
1914        .update(cx, |project, cx| {
1915            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1916        })
1917        .await
1918        .unwrap();
1919    let private_buffer = project
1920        .update(cx, |project, cx| {
1921            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1922        })
1923        .await
1924        .unwrap();
1925
1926    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1927    ep_store.update(cx, |ep_store, _cx| {
1928        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1929    });
1930
1931    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1932    assert_eq!(
1933        captured_request.lock().clone().unwrap().can_collect_data,
1934        true
1935    );
1936
1937    // this has a side effect of registering the buffer to watch for edits
1938    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1939    assert_eq!(
1940        captured_request.lock().clone().unwrap().can_collect_data,
1941        false
1942    );
1943
1944    private_buffer.update(cx, |private_buffer, cx| {
1945        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1946    });
1947
1948    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1949    assert_eq!(
1950        captured_request.lock().clone().unwrap().can_collect_data,
1951        false
1952    );
1953
1954    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1955    // included
1956    buffer.update(cx, |buffer, cx| {
1957        buffer.edit(
1958            [(
1959                0..0,
1960                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1961            )],
1962            None,
1963            cx,
1964        );
1965    });
1966
1967    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1968    assert_eq!(
1969        captured_request.lock().clone().unwrap().can_collect_data,
1970        true
1971    );
1972}
1973
1974fn init_test(cx: &mut TestAppContext) {
1975    cx.update(|cx| {
1976        let settings_store = SettingsStore::test(cx);
1977        cx.set_global(settings_store);
1978    });
1979}
1980
1981async fn apply_edit_prediction(
1982    buffer_content: &str,
1983    completion_response: &str,
1984    cx: &mut TestAppContext,
1985) -> String {
1986    let fs = project::FakeFs::new(cx.executor());
1987    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1988    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1989    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1990    *response.lock() = completion_response.to_string();
1991    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1992    buffer.update(cx, |buffer, cx| {
1993        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1994    });
1995    buffer.read_with(cx, |buffer, _| buffer.text())
1996}
1997
1998async fn run_edit_prediction(
1999    buffer: &Entity<Buffer>,
2000    project: &Entity<Project>,
2001    ep_store: &Entity<EditPredictionStore>,
2002    cx: &mut TestAppContext,
2003) -> EditPrediction {
2004    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2005    ep_store.update(cx, |ep_store, cx| {
2006        ep_store.register_buffer(buffer, &project, cx)
2007    });
2008    cx.background_executor.run_until_parked();
2009    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2010        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2011    });
2012    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2013}
2014
/// Builds an `EditPredictionStore` backed by a fake HTTP client and a fake
/// collab server.
///
/// Returns a tuple of:
/// - the store entity (Zeta1 model pre-selected, license-detection watchers
///   installed for every worktree of `project`),
/// - a handle that captures the most recent `PredictEditsBody` POSTed to
///   `/predict_edits/v2`,
/// - a handle whose contents are served back as the prediction's
///   `output_excerpt` (pre-filled with a minimal "hello world" excerpt).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    // Default model output: one editable region containing "hello world".
    let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        // Counter used to mint a unique request id per prediction request.
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token endpoint: always hand out the same fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction endpoint: record the request body for the
                    // test to inspect, and reply with the configured excerpt.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    // Anything else is unexpected in these tests.
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Install a license-detection watcher per worktree so data-collection
        // eligibility can be determined in tests.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2105
2106fn to_completion_edits(
2107    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2108    buffer: &Entity<Buffer>,
2109    cx: &App,
2110) -> Vec<(Range<Anchor>, Arc<str>)> {
2111    let buffer = buffer.read(cx);
2112    iterator
2113        .into_iter()
2114        .map(|(range, text)| {
2115            (
2116                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2117                text,
2118            )
2119        })
2120        .collect()
2121}
2122
2123fn from_completion_edits(
2124    editor_edits: &[(Range<Anchor>, Arc<str>)],
2125    buffer: &Entity<Buffer>,
2126    cx: &App,
2127) -> Vec<(Range<usize>, Arc<str>)> {
2128    let buffer = buffer.read(cx);
2129    editor_edits
2130        .iter()
2131        .map(|(range, text)| {
2132            (
2133                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2134                text.clone(),
2135            )
2136        })
2137        .collect()
2138}
2139
2140#[gpui::test]
2141async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2142    init_test(cx);
2143
2144    let fs = FakeFs::new(cx.executor());
2145    fs.insert_tree(
2146        "/project",
2147        serde_json::json!({
2148            "main.rs": "fn main() {\n    \n}\n"
2149        }),
2150    )
2151    .await;
2152
2153    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2154
2155    let http_client = FakeHttpClient::create(|_req| async move {
2156        Ok(gpui::http_client::Response::builder()
2157            .status(401)
2158            .body("Unauthorized".into())
2159            .unwrap())
2160    });
2161
2162    let client =
2163        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2164    cx.update(|cx| {
2165        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2166    });
2167
2168    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2169
2170    let buffer = project
2171        .update(cx, |project, cx| {
2172            let path = project
2173                .find_project_path(path!("/project/main.rs"), cx)
2174                .unwrap();
2175            project.open_buffer(path, cx)
2176        })
2177        .await
2178        .unwrap();
2179
2180    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2181    ep_store.update(cx, |ep_store, cx| {
2182        ep_store.register_buffer(&buffer, &project, cx)
2183    });
2184    cx.background_executor.run_until_parked();
2185
2186    let completion_task = ep_store.update(cx, |ep_store, cx| {
2187        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2188        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2189    });
2190
2191    let result = completion_task.await;
2192    assert!(
2193        result.is_err(),
2194        "Without authentication and without custom URL, prediction should fail"
2195    );
2196}
2197
// Counterpart to the test above: when a custom predict-edits URL is set, the
// prediction request must go straight to that endpoint even though the client
// cannot authenticate (every non-predict request gets a 401).
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Flag flipped by the fake server when the predict endpoint is hit.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                if uri.contains("predict") {
                    // The custom endpoint speaks the OpenAI chat-completion
                    // shape; reply with a canned completion.
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                                ```main.rs
                                                <|start_of_file|>
                                                <|editable_region_start|>
                                                fn main() {
                                                    println!(\"Hello, world!\");
                                                }
                                                <|editable_region_end|>
                                                ```
                                            "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    // Everything else (token refresh, etc.) is unauthorized.
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        // The custom URL is what lets the request bypass authentication.
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // The task's result is irrelevant here; we only assert the endpoint was
    // reached.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2307
// Verifies that `compute_diff_between_snapshots` produces a unified diff with
// both insertions merged into a single hunk with three lines of context.
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    buffer.update(cx, |buffer, cx| {
        // Insert at the later row first so the earlier row's coordinates are
        // unaffected by the row shift the first insertion causes.
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Both insertions are close enough to share one hunk (old lines 6-15 map
    // to new lines 6-17).
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
            "}
    );
}
2375
/// Initializes test logging before any test runs (via `ctor`), so log output
/// is available even for failures during test setup.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}