edit_prediction_tests.rs

use super::*;
use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
use cloud_llm_client::{
    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
    RejectEditPredictionsBody,
    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
};
use futures::{
    AsyncReadExt, StreamExt,
    channel::{mpsc, oneshot},
};
use gpui::{
    Entity, TestAppContext,
    http_client::{FakeHttpClient, Response},
};
use indoc::indoc;
use language::Point;
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
use util::{path, rel_path::rel_path};
use uuid::Uuid;
use zeta_prompt::ZetaPromptInput;

use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};

#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

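    // Publishing diagnostics triggers a new prediction request for the affected
    // file, so no explicit refresh call is needed here.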
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}

#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}

#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}

#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}

#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}

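// Renders each event's buffer-change diff, joining consecutive events with a "---" separator.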
fn render_events(events: &[StoredEvent]) -> String {
    events
        .iter()
        .map(|e| {
            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
            diff.as_str()
        })
        .collect::<Vec<_>>()
        .join("\n---\n")
}

#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}

#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}

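// Minimal diff reused across these tests: expands "How" to "How are you?" in root/foo.md.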
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};

#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}

#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}

#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}

#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as cancelled, and first as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}

#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // Both rejections are batched into a single request
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, _cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}

// Skipped until we start including diagnostics in prompt
// #[gpui::test]
// async fn test_request_diagnostics(cx: &mut TestAppContext) {
//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
//     let fs = FakeFs::new(cx.executor());
//     fs.insert_tree(
//         "/root",
//         json!({
//             "foo.md": "Hello!\nBye"
//         }),
//     )
//     .await;
//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
//     let diagnostic = lsp::Diagnostic {
//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
//         severity: Some(lsp::DiagnosticSeverity::ERROR),
//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
//         ..Default::default()
//     };

//     project.update(cx, |project, cx| {
//         project.lsp_store().update(cx, |lsp_store, cx| {
//             // Create some diagnostics
//             lsp_store
//                 .update_diagnostics(
//                     LanguageServerId(0),
//                     lsp::PublishDiagnosticsParams {
//                         uri: path_to_buffer_uri.clone(),
//                         diagnostics: vec![diagnostic],
//                         version: None,
//                     },
//                     None,
//                     language::DiagnosticSourceKind::Pushed,
//                     &[],
//                     cx,
//                 )
//                 .unwrap();
//         });
//     });

//     let buffer = project
//         .update(cx, |project, cx| {
//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
//             project.open_buffer(path, cx)
//         })
//         .await
//         .unwrap();

//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
//     let position = snapshot.anchor_before(language::Point::new(0, 0));

//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
//         ep_store.request_prediction(&project, &buffer, position, cx)
//     });

//     let (request, _respond_tx) = req_rx.next().await.unwrap();

//     assert_eq!(request.diagnostic_groups.len(), 1);
//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
//         .unwrap();
//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
//     assert_eq!(
//         value,
//         json!({
//             "entries": [{
//                 "range": {
//                     "start": 8,
//                     "end": 10
//                 },
//                 "diagnostic": {
//                     "source": null,
//                     "code": null,
//                     "code_description": null,
//                     "severity": 1,
//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
//                     "markdown": null,
//                     "group_id": 0,
//                     "is_primary": true,
//                     "is_disk_based": false,
//                     "is_unnecessary": false,
//                     "source_kind": "Pushed",
//                     "data": null,
//                     "underline": true
//                 }
//             }],
//             "primary_ix": 0
//         })
//     );
// }

// Generate a model response that would apply the given diff to the active file.
fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
    let excerpt =
        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();

    PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: new_excerpt,
    }
}

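// Renders the zeta prompt for a captured request so tests can assert on its contents.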
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
}

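// Channels on which the fake HTTP client forwards intercepted predict and reject requests.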
struct RequestChannels {
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}

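// Builds an `EditPredictionStore` backed by a fake HTTP client. Predict and reject requests
// are forwarded on the returned channels together with a oneshot sender, so each test can
// inspect the request and reply with a canned response.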
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}

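// 0BSD license text, inserted as a LICENSE file so data-collection tests run against a
// permissively licensed project.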
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");

#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}

#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
    );

    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
    );
}

#[gpui::test]
async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let buffer_content = "lorem\n";
    let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

    assert_eq!(
        apply_edit_prediction(buffer_content, completion_response, cx).await,
        "lorem\nipsum"
    );
}

#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // When the buffer ends without a trailing newline but the model returns output
    // with one, zeta2's newline normalization should reconcile both sides before
    // diffing so that no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}

1679#[gpui::test]
1680async fn test_can_collect_data(cx: &mut TestAppContext) {
1681    init_test(cx);
1682
1683    let fs = project::FakeFs::new(cx.executor());
1684    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1685        .await;
1686
1687    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1688    let buffer = project
1689        .update(cx, |project, cx| {
1690            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1691        })
1692        .await
1693        .unwrap();
1694
1695    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1696    ep_store.update(cx, |ep_store, _cx| {
1697        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1698    });
1699
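        // The worktree contains a permissive LICENSE (BSD_0_TXT) and collection is enabled,
        // so the request body should report can_collect_data = true.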
1700    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1701    assert_eq!(
1702        captured_request.lock().clone().unwrap().can_collect_data,
1703        true
1704    );
1705
1706    ep_store.update(cx, |ep_store, _cx| {
1707        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1708    });
1709
1710    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1711    assert_eq!(
1712        captured_request.lock().clone().unwrap().can_collect_data,
1713        false
1714    );
1715}
1716
1717#[gpui::test]
1718async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1719    init_test(cx);
1720
1721    let fs = project::FakeFs::new(cx.executor());
1722    let project = Project::test(fs.clone(), [], cx).await;
1723
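        // A buffer constructed with Buffer::remote has no local file, so its contents must
        // never be collected, regardless of the user's choice.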
1724    let buffer = cx.new(|_cx| {
1725        Buffer::remote(
1726            language::BufferId::new(1).unwrap(),
1727            ReplicaId::new(1),
1728            language::Capability::ReadWrite,
1729            "fn main() {\n    println!(\"Hello\");\n}",
1730        )
1731    });
1732
1733    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1734    ep_store.update(cx, |ep_store, _cx| {
1735        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1736    });
1737
1738    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1739    assert_eq!(
1740        captured_request.lock().clone().unwrap().can_collect_data,
1741        false
1742    );
1743}
1744
1745#[gpui::test]
1746async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1747    init_test(cx);
1748
1749    let fs = project::FakeFs::new(cx.executor());
1750    fs.insert_tree(
1751        path!("/project"),
1752        json!({
1753            "LICENSE": BSD_0_TXT,
1754            ".env": "SECRET_KEY=secret"
1755        }),
1756    )
1757    .await;
1758
1759    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1760    let buffer = project
1761        .update(cx, |project, cx| {
1762            project.open_local_buffer(path!("/project/.env"), cx)
1763        })
1764        .await
1765        .unwrap();
1766
1767    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1768    ep_store.update(cx, |ep_store, _cx| {
1769        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1770    });
1771
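        // .env is treated as a private file (presumably via the default private-path
        // settings), so collection must stay off even though the user opted in.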
1772    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1773    assert_eq!(
1774        captured_request.lock().clone().unwrap().can_collect_data,
1775        false
1776    );
1777}
1778
1779#[gpui::test]
1780async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1781    init_test(cx);
1782
1783    let fs = project::FakeFs::new(cx.executor());
1784    let project = Project::test(fs.clone(), [], cx).await;
1785    let buffer = cx.new(|cx| Buffer::local("", cx));
1786
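        // An untitled buffer isn't backed by any file or worktree, so it can never be
        // eligible for collection.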
1787    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1788    ep_store.update(cx, |ep_store, _cx| {
1789        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1790    });
1791
1792    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1793    assert_eq!(
1794        captured_request.lock().clone().unwrap().can_collect_data,
1795        false
1796    );
1797}
1798
1799#[gpui::test]
1800async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1801    init_test(cx);
1802
1803    let fs = project::FakeFs::new(cx.executor());
1804    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1805        .await;
1806
1807    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1808    let buffer = project
1809        .update(cx, |project, cx| {
1810            project.open_local_buffer(path!("/project/main.rs"), cx)
1811        })
1812        .await
1813        .unwrap();
1814
1815    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1816    ep_store.update(cx, |ep_store, _cx| {
1817        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1818    });
1819
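        // This worktree has no LICENSE file, so the project counts as closed source and
        // collection stays off despite the user opting in.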
1820    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1821    assert_eq!(
1822        captured_request.lock().clone().unwrap().can_collect_data,
1823        false
1824    );
1825}
1826
1827#[gpui::test]
1828async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1829    init_test(cx);
1830
1831    let fs = project::FakeFs::new(cx.executor());
1832    fs.insert_tree(
1833        path!("/open_source_worktree"),
1834        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1835    )
1836    .await;
1837    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1838        .await;
1839
1840    let project = Project::test(
1841        fs.clone(),
1842        [
1843            path!("/open_source_worktree").as_ref(),
1844            path!("/closed_source_worktree").as_ref(),
1845        ],
1846        cx,
1847    )
1848    .await;
1849    let buffer = project
1850        .update(cx, |project, cx| {
1851            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1852        })
1853        .await
1854        .unwrap();
1855
1856    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1857    ep_store.update(cx, |ep_store, _cx| {
1858        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1859    });
1860
1861    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1862    assert_eq!(
1863        captured_request.lock().clone().unwrap().can_collect_data,
1864        true
1865    );
1866
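        // Reassign the buffer's file to one in the worktree without a LICENSE; predictions
        // made after the move should no longer be collectable.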
1867    let closed_source_file = project
1868        .update(cx, |project, cx| {
1869            let worktree2 = project
1870                .worktree_for_root_name("closed_source_worktree", cx)
1871                .unwrap();
1872            worktree2.update(cx, |worktree2, cx| {
1873                worktree2.load_file(rel_path("main.rs"), cx)
1874            })
1875        })
1876        .await
1877        .unwrap()
1878        .file;
1879
1880    buffer.update(cx, |buffer, cx| {
1881        buffer.file_updated(closed_source_file, cx);
1882    });
1883
1884    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1885    assert_eq!(
1886        captured_request.lock().clone().unwrap().can_collect_data,
1887        false
1888    );
1889}
1890
1891#[gpui::test]
1892async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1893    init_test(cx);
1894
1895    let fs = project::FakeFs::new(cx.executor());
1896    fs.insert_tree(
1897        path!("/worktree1"),
1898        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1899    )
1900    .await;
1901    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1902        .await;
1903
1904    let project = Project::test(
1905        fs.clone(),
1906        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1907        cx,
1908    )
1909    .await;
1910    let buffer = project
1911        .update(cx, |project, cx| {
1912            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1913        })
1914        .await
1915        .unwrap();
1916    let private_buffer = project
1917        .update(cx, |project, cx| {
1918            project.open_local_buffer(path!("/worktree2/private.rs"), cx)
1919        })
1920        .await
1921        .unwrap();
1922
1923    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1924    ep_store.update(cx, |ep_store, _cx| {
1925        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1926    });
1927
1928    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1929    assert_eq!(
1930        captured_request.lock().clone().unwrap().can_collect_data,
1931        true
1932    );
1933
1934    // This also has the side effect of registering private_buffer with the store so that its edits are tracked
1935    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1936    assert_eq!(
1937        captured_request.lock().clone().unwrap().can_collect_data,
1938        false
1939    );
1940
1941    private_buffer.update(cx, |private_buffer, cx| {
1942        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1943    });
1944
1945    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1946    assert_eq!(
1947        captured_request.lock().clone().unwrap().can_collect_data,
1948        false
1949    );
1950
1951    // Make an edit large enough to exceed the event token budget, so the earlier
1952    // private_buffer edit can no longer be included in the event history
1953    buffer.update(cx, |buffer, cx| {
1954        buffer.edit(
1955            [(
1956                0..0,
1957                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1958            )],
1959            None,
1960            cx,
1961        );
1962    });
1963
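        // With the private_buffer edit pushed out of the event window, collection becomes
        // allowed again.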
1964    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1965    assert_eq!(
1966        captured_request.lock().clone().unwrap().can_collect_data,
1967        true
1968    );
1969}
1970
1971fn init_test(cx: &mut TestAppContext) {
1972    cx.update(|cx| {
1973        let settings_store = SettingsStore::test(cx);
1974        cx.set_global(settings_store);
1975    });
1976}
1977
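    // Test helper: runs a single prediction against a fresh local buffer using the given
    // canned completion response, applies the predicted edits, and returns the resulting
    // buffer text.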
1978async fn apply_edit_prediction(
1979    buffer_content: &str,
1980    completion_response: &str,
1981    cx: &mut TestAppContext,
1982) -> String {
1983    let fs = project::FakeFs::new(cx.executor());
1984    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1985    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1986    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1987    *response.lock() = completion_response.to_string();
1988    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1989    buffer.update(cx, |buffer, cx| {
1990        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1991    });
1992    buffer.read_with(cx, |buffer, _| buffer.text())
1993}
1994
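    // Test helper: registers the buffer with the store, requests a prediction with the
    // cursor at row 1, column 0, and unwraps the resulting EditPrediction.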
1995async fn run_edit_prediction(
1996    buffer: &Entity<Buffer>,
1997    project: &Entity<Project>,
1998    ep_store: &Entity<EditPredictionStore>,
1999    cx: &mut TestAppContext,
2000) -> EditPrediction {
2001    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2002    ep_store.update(cx, |ep_store, cx| {
2003        ep_store.register_buffer(buffer, &project, cx)
2004    });
2005    cx.background_executor.run_until_parked();
2006    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2007        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2008    });
2009    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2010}
2011
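    // Test helper: builds an EditPredictionStore backed by a FakeHttpClient that serves
    // /client/llm_tokens with a canned token and /predict_edits/v2 with the configurable
    // completion response, capturing each prediction request body for later assertions.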
2012async fn make_test_ep_store(
2013    project: &Entity<Project>,
2014    cx: &mut TestAppContext,
2015) -> (
2016    Entity<EditPredictionStore>,
2017    Arc<Mutex<Option<PredictEditsBody>>>,
2018    Arc<Mutex<String>>,
2019) {
2020    let default_response = indoc! {"
2021            ```main.rs
2022            <|start_of_file|>
2023            <|editable_region_start|>
2024            hello world
2025            <|editable_region_end|>
2026            ```"
2027    };
2028    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2029    let completion_response: Arc<Mutex<String>> =
2030        Arc::new(Mutex::new(default_response.to_string()));
2031    let http_client = FakeHttpClient::create({
2032        let captured_request = captured_request.clone();
2033        let completion_response = completion_response.clone();
2034        let mut next_request_id = 0;
2035        move |req| {
2036            let captured_request = captured_request.clone();
2037            let completion_response = completion_response.clone();
2038            async move {
2039                match (req.method(), req.uri().path()) {
2040                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2041                        .status(200)
2042                        .body(
2043                            serde_json::to_string(&CreateLlmTokenResponse {
2044                                token: LlmToken("the-llm-token".to_string()),
2045                            })
2046                            .unwrap()
2047                            .into(),
2048                        )
2049                        .unwrap()),
2050                    (&Method::POST, "/predict_edits/v2") => {
2051                        let mut request_body = String::new();
2052                        req.into_body().read_to_string(&mut request_body).await?;
2053                        *captured_request.lock() =
2054                            Some(serde_json::from_str(&request_body).unwrap());
2055                        next_request_id += 1;
2056                        Ok(http_client::Response::builder()
2057                            .status(200)
2058                            .body(
2059                                serde_json::to_string(&PredictEditsResponse {
2060                                    request_id: format!("request-{next_request_id}"),
2061                                    output_excerpt: completion_response.lock().clone(),
2062                                })
2063                                .unwrap()
2064                                .into(),
2065                            )
2066                            .unwrap())
2067                    }
2068                    _ => Ok(http_client::Response::builder()
2069                        .status(404)
2070                        .body("Not Found".into())
2071                        .unwrap()),
2072                }
2073            }
2074        }
2075    });
2076
2077    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2078    cx.update(|cx| {
2079        RefreshLlmTokenListener::register(client.clone(), cx);
2080    });
2081    let _server = FakeServer::for_client(42, &client, cx).await;
2082
2083    let ep_store = cx.new(|cx| {
2084        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2085        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2086
2087        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2088        for worktree in worktrees {
2089            let worktree_id = worktree.read(cx).id();
2090            ep_store
2091                .get_or_init_project(project, cx)
2092                .license_detection_watchers
2093                .entry(worktree_id)
2094                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2095        }
2096
2097        ep_store
2098    });
2099
2100    (ep_store, captured_request, completion_response)
2101}
2102
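    // Convert between the offset-based edits used in test expectations and the
    // anchor-based edits stored on predictions.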
2103fn to_completion_edits(
2104    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2105    buffer: &Entity<Buffer>,
2106    cx: &App,
2107) -> Vec<(Range<Anchor>, Arc<str>)> {
2108    let buffer = buffer.read(cx);
2109    iterator
2110        .into_iter()
2111        .map(|(range, text)| {
2112            (
2113                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2114                text,
2115            )
2116        })
2117        .collect()
2118}
2119
2120fn from_completion_edits(
2121    editor_edits: &[(Range<Anchor>, Arc<str>)],
2122    buffer: &Entity<Buffer>,
2123    cx: &App,
2124) -> Vec<(Range<usize>, Arc<str>)> {
2125    let buffer = buffer.read(cx);
2126    editor_edits
2127        .iter()
2128        .map(|(range, text)| {
2129            (
2130                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2131                text.clone(),
2132            )
2133        })
2134        .collect()
2135}
2136
2137#[gpui::test]
2138async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2139    init_test(cx);
2140
2141    let fs = FakeFs::new(cx.executor());
2142    fs.insert_tree(
2143        "/project",
2144        serde_json::json!({
2145            "main.rs": "fn main() {\n    \n}\n"
2146        }),
2147    )
2148    .await;
2149
2150    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2151
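        // The fake server answers every request with 401, simulating a signed-out client
        // that cannot obtain an LLM token.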
2152    let http_client = FakeHttpClient::create(|_req| async move {
2153        Ok(gpui::http_client::Response::builder()
2154            .status(401)
2155            .body("Unauthorized".into())
2156            .unwrap())
2157    });
2158
2159    let client =
2160        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2161    cx.update(|cx| {
2162        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2163    });
2164
2165    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2166
2167    let buffer = project
2168        .update(cx, |project, cx| {
2169            let path = project
2170                .find_project_path(path!("/project/main.rs"), cx)
2171                .unwrap();
2172            project.open_buffer(path, cx)
2173        })
2174        .await
2175        .unwrap();
2176
2177    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2178    ep_store.update(cx, |ep_store, cx| {
2179        ep_store.register_buffer(&buffer, &project, cx)
2180    });
2181    cx.background_executor.run_until_parked();
2182
2183    let completion_task = ep_store.update(cx, |ep_store, cx| {
2184        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2185        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2186    });
2187
2188    let result = completion_task.await;
2189    assert!(
2190        result.is_err(),
2191        "Without authentication and without custom URL, prediction should fail"
2192    );
2193}
2194
2195#[gpui::test]
2196async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2197    init_test(cx);
2198
2199    let fs = FakeFs::new(cx.executor());
2200    fs.insert_tree(
2201        "/project",
2202        serde_json::json!({
2203            "main.rs": "fn main() {\n    \n}\n"
2204        }),
2205    )
2206    .await;
2207
2208    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2209
2210    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2211    let predict_called_clone = predict_called.clone();
2212
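        // The fake server still rejects token requests with 401, but answers any path
        // containing "predict" with an OpenAI-style chat completion, standing in for a
        // custom prediction endpoint.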
2213    let http_client = FakeHttpClient::create({
2214        move |req| {
2215            let uri = req.uri().path().to_string();
2216            let predict_called = predict_called_clone.clone();
2217            async move {
2218                if uri.contains("predict") {
2219                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2220                    Ok(gpui::http_client::Response::builder()
2221                        .body(
2222                            serde_json::to_string(&open_ai::Response {
2223                                id: "test-123".to_string(),
2224                                object: "chat.completion".to_string(),
2225                                created: 0,
2226                                model: "test".to_string(),
2227                                usage: open_ai::Usage {
2228                                    prompt_tokens: 0,
2229                                    completion_tokens: 0,
2230                                    total_tokens: 0,
2231                                },
2232                                choices: vec![open_ai::Choice {
2233                                    index: 0,
2234                                    message: open_ai::RequestMessage::Assistant {
2235                                        content: Some(open_ai::MessageContent::Plain(
2236                                            indoc! {"
2237                                                ```main.rs
2238                                                <|start_of_file|>
2239                                                <|editable_region_start|>
2240                                                fn main() {
2241                                                    println!(\"Hello, world!\");
2242                                                }
2243                                                <|editable_region_end|>
2244                                                ```
2245                                            "}
2246                                            .to_string(),
2247                                        )),
2248                                        tool_calls: vec![],
2249                                    },
2250                                    finish_reason: Some("stop".to_string()),
2251                                }],
2252                            })
2253                            .unwrap()
2254                            .into(),
2255                        )
2256                        .unwrap())
2257                } else {
2258                    Ok(gpui::http_client::Response::builder()
2259                        .status(401)
2260                        .body("Unauthorized".into())
2261                        .unwrap())
2262                }
2263            }
2264        }
2265    });
2266
2267    let client =
2268        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2269    cx.update(|cx| {
2270        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2271    });
2272
2273    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2274
2275    let buffer = project
2276        .update(cx, |project, cx| {
2277            let path = project
2278                .find_project_path(path!("/project/main.rs"), cx)
2279                .unwrap();
2280            project.open_buffer(path, cx)
2281        })
2282        .await
2283        .unwrap();
2284
2285    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2286    ep_store.update(cx, |ep_store, cx| {
2287        ep_store.register_buffer(&buffer, &project, cx)
2288    });
2289    cx.background_executor.run_until_parked();
2290
2291    let completion_task = ep_store.update(cx, |ep_store, cx| {
2292        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2293        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2294        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2295    });
2296
2297    let _ = completion_task.await;
2298
2299    assert!(
2300        predict_called.load(std::sync::atomic::Ordering::SeqCst),
2301        "With custom URL, predict endpoint should be called even without authentication"
2302    );
2303}
2304
2305#[gpui::test]
2306fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2307    let buffer = cx.new(|cx| {
2308        Buffer::local(
2309            indoc! {"
2310                zero
2311                one
2312                two
2313                three
2314                four
2315                five
2316                six
2317                seven
2318                eight
2319                nine
2320                ten
2321                eleven
2322                twelve
2323                thirteen
2324                fourteen
2325                fifteen
2326                sixteen
2327                seventeen
2328                eighteen
2329                nineteen
2330                twenty
2331                twenty-one
2332                twenty-two
2333                twenty-three
2334                twenty-four
2335            "},
2336            cx,
2337        )
2338    });
2339
2340    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2341
2342    buffer.update(cx, |buffer, cx| {
2343        let point = Point::new(12, 0);
2344        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2345        let point = Point::new(8, 0);
2346        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2347    });
2348
2349    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2350
2351    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2352
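        // Both insertions fall into a single hunk: `@@ -6,10 +6,12 @@` means the hunk
        // starts at line 6 on both sides and spans 10 old lines vs. 12 new lines. With
        // three lines of context, the two nearby insertions merge into one hunk.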
2353    assert_eq!(
2354        diff,
2355        indoc! {"
2356            @@ -6,10 +6,12 @@
2357             five
2358             six
2359             seven
2360            +FIRST INSERTION
2361             eight
2362             nine
2363             ten
2364             eleven
2365            +SECOND INSERTION
2366             twelve
2367             thirteen
2368             fourteen
2369            "}
2370    );
2371}
2372
2373#[ctor::ctor]
2374fn init_logger() {
2375    zlog::init_test();
2376}