edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9};
  10use futures::{
  11    AsyncReadExt, StreamExt,
  12    channel::{mpsc, oneshot},
  13};
  14use gpui::{
  15    Entity, TestAppContext,
  16    http_client::{FakeHttpClient, Response},
  17};
  18use indoc::indoc;
  19use language::Point;
  20use lsp::LanguageServerId;
  21use open_ai::Usage;
  22use parking_lot::Mutex;
  23use pretty_assertions::{assert_eq, assert_matches};
  24use project::{FakeFs, Project};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, sync::Arc, time::Duration};
  28use util::{path, rel_path::rel_path};
  29use uuid::Uuid;
  30use zeta_prompt::ZetaPromptInput;
  31
  32use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  33
  34#[gpui::test]
  35async fn test_current_state(cx: &mut TestAppContext) {
  36    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  37    let fs = FakeFs::new(cx.executor());
  38    fs.insert_tree(
  39        "/root",
  40        json!({
  41            "1.txt": "Hello!\nHow\nBye\n",
  42            "2.txt": "Hola!\nComo\nAdios\n"
  43        }),
  44    )
  45    .await;
  46    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  47
  48    let buffer1 = project
  49        .update(cx, |project, cx| {
  50            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  51            project.set_active_path(Some(path.clone()), cx);
  52            project.open_buffer(path, cx)
  53        })
  54        .await
  55        .unwrap();
  56    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  57    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  58
  59    ep_store.update(cx, |ep_store, cx| {
  60        ep_store.register_project(&project, cx);
  61        ep_store.register_buffer(&buffer1, &project, cx);
  62    });
  63
  64    // Prediction for current file
  65
  66    ep_store.update(cx, |ep_store, cx| {
  67        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  68    });
  69    let (request, respond_tx) = requests.predict.next().await.unwrap();
  70
  71    respond_tx
  72        .send(model_response(
  73            request,
  74            indoc! {r"
  75                --- a/root/1.txt
  76                +++ b/root/1.txt
  77                @@ ... @@
  78                 Hello!
  79                -How
  80                +How are you?
  81                 Bye
  82            "},
  83        ))
  84        .unwrap();
  85
  86    cx.run_until_parked();
  87
  88    ep_store.update(cx, |ep_store, cx| {
  89        let prediction = ep_store
  90            .prediction_at(&buffer1, None, &project, cx)
  91            .unwrap();
  92        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  93    });
  94
  95    ep_store.update(cx, |ep_store, _cx| {
  96        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
  97    });
  98
  99    // Prediction for diagnostic in another file
 100
 101    let diagnostic = lsp::Diagnostic {
 102        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 103        severity: Some(lsp::DiagnosticSeverity::ERROR),
 104        message: "Sentence is incomplete".to_string(),
 105        ..Default::default()
 106    };
 107
 108    project.update(cx, |project, cx| {
 109        project.lsp_store().update(cx, |lsp_store, cx| {
 110            lsp_store
 111                .update_diagnostics(
 112                    LanguageServerId(0),
 113                    lsp::PublishDiagnosticsParams {
 114                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 115                        diagnostics: vec![diagnostic],
 116                        version: None,
 117                    },
 118                    None,
 119                    language::DiagnosticSourceKind::Pushed,
 120                    &[],
 121                    cx,
 122                )
 123                .unwrap();
 124        });
 125    });
 126
 127    let (request, respond_tx) = requests.predict.next().await.unwrap();
 128    respond_tx
 129        .send(model_response(
 130            request,
 131            indoc! {r#"
 132                --- a/root/2.txt
 133                +++ b/root/2.txt
 134                @@ ... @@
 135                 Hola!
 136                -Como
 137                +Como estas?
 138                 Adios
 139            "#},
 140        ))
 141        .unwrap();
 142    cx.run_until_parked();
 143
 144    ep_store.update(cx, |ep_store, cx| {
 145        let prediction = ep_store
 146            .prediction_at(&buffer1, None, &project, cx)
 147            .unwrap();
 148        assert_matches!(
 149            prediction,
 150            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 151        );
 152    });
 153
 154    let buffer2 = project
 155        .update(cx, |project, cx| {
 156            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 157            project.open_buffer(path, cx)
 158        })
 159        .await
 160        .unwrap();
 161
 162    ep_store.update(cx, |ep_store, cx| {
 163        let prediction = ep_store
 164            .prediction_at(&buffer2, None, &project, cx)
 165            .unwrap();
 166        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 167    });
 168}
 169
 170#[gpui::test]
 171async fn test_simple_request(cx: &mut TestAppContext) {
 172    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 173    let fs = FakeFs::new(cx.executor());
 174    fs.insert_tree(
 175        "/root",
 176        json!({
 177            "foo.md":  "Hello!\nHow\nBye\n"
 178        }),
 179    )
 180    .await;
 181    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 182
 183    let buffer = project
 184        .update(cx, |project, cx| {
 185            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 186            project.open_buffer(path, cx)
 187        })
 188        .await
 189        .unwrap();
 190    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 191    let position = snapshot.anchor_before(language::Point::new(1, 3));
 192
 193    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 194        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 195    });
 196
 197    let (request, respond_tx) = requests.predict.next().await.unwrap();
 198
 199    // TODO Put back when we have a structured request again
 200    // assert_eq!(
 201    //     request.excerpt_path.as_ref(),
 202    //     Path::new(path!("root/foo.md"))
 203    // );
 204    // assert_eq!(
 205    //     request.cursor_point,
 206    //     Point {
 207    //         line: Line(1),
 208    //         column: 3
 209    //     }
 210    // );
 211
 212    respond_tx
 213        .send(model_response(
 214            request,
 215            indoc! { r"
 216                --- a/root/foo.md
 217                +++ b/root/foo.md
 218                @@ ... @@
 219                 Hello!
 220                -How
 221                +How are you?
 222                 Bye
 223            "},
 224        ))
 225        .unwrap();
 226
 227    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 228
 229    assert_eq!(prediction.edits.len(), 1);
 230    assert_eq!(
 231        prediction.edits[0].0.to_point(&snapshot).start,
 232        language::Point::new(1, 3)
 233    );
 234    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 235}
 236
 237#[gpui::test]
 238async fn test_request_events(cx: &mut TestAppContext) {
 239    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 240    let fs = FakeFs::new(cx.executor());
 241    fs.insert_tree(
 242        "/root",
 243        json!({
 244            "foo.md": "Hello!\n\nBye\n"
 245        }),
 246    )
 247    .await;
 248    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 249
 250    let buffer = project
 251        .update(cx, |project, cx| {
 252            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 253            project.open_buffer(path, cx)
 254        })
 255        .await
 256        .unwrap();
 257
 258    ep_store.update(cx, |ep_store, cx| {
 259        ep_store.register_buffer(&buffer, &project, cx);
 260    });
 261
 262    buffer.update(cx, |buffer, cx| {
 263        buffer.edit(vec![(7..7, "How")], None, cx);
 264    });
 265
 266    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 267    let position = snapshot.anchor_before(language::Point::new(1, 3));
 268
 269    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 270        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 271    });
 272
 273    let (request, respond_tx) = requests.predict.next().await.unwrap();
 274
 275    let prompt = prompt_from_request(&request);
 276    assert!(
 277        prompt.contains(indoc! {"
 278        --- a/root/foo.md
 279        +++ b/root/foo.md
 280        @@ -1,3 +1,3 @@
 281         Hello!
 282        -
 283        +How
 284         Bye
 285    "}),
 286        "{prompt}"
 287    );
 288
 289    respond_tx
 290        .send(model_response(
 291            request,
 292            indoc! {r#"
 293                --- a/root/foo.md
 294                +++ b/root/foo.md
 295                @@ ... @@
 296                 Hello!
 297                -How
 298                +How are you?
 299                 Bye
 300        "#},
 301        ))
 302        .unwrap();
 303
 304    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 305
 306    assert_eq!(prediction.edits.len(), 1);
 307    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 308}
 309
 310#[gpui::test]
 311async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 312    let (ep_store, _requests) = init_test_with_fake_client(cx);
 313    let fs = FakeFs::new(cx.executor());
 314    fs.insert_tree(
 315        "/root",
 316        json!({
 317            "foo.md": "Hello!\n\nBye\n"
 318        }),
 319    )
 320    .await;
 321    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 322
 323    let buffer = project
 324        .update(cx, |project, cx| {
 325            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 326            project.open_buffer(path, cx)
 327        })
 328        .await
 329        .unwrap();
 330
 331    ep_store.update(cx, |ep_store, cx| {
 332        ep_store.register_buffer(&buffer, &project, cx);
 333    });
 334
 335    // First burst: insert "How"
 336    buffer.update(cx, |buffer, cx| {
 337        buffer.edit(vec![(7..7, "How")], None, cx);
 338    });
 339
 340    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 341    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 342    cx.run_until_parked();
 343
 344    // Second burst: append " are you?" immediately after "How" on the same line.
 345    //
 346    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 347    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 348    buffer.update(cx, |buffer, cx| {
 349        buffer.edit(vec![(10..10, " are you?")], None, cx);
 350    });
 351
 352    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 353    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 354    buffer.update(cx, |buffer, cx| {
 355        buffer.edit(vec![(19..19, "!")], None, cx);
 356    });
 357
 358    // Without time-based splitting, there is one event.
 359    let events = ep_store.update(cx, |ep_store, cx| {
 360        ep_store.edit_history_for_project(&project, cx)
 361    });
 362    assert_eq!(events.len(), 1);
 363    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 364    assert_eq!(
 365        diff.as_str(),
 366        indoc! {"
 367            @@ -1,3 +1,3 @@
 368             Hello!
 369            -
 370            +How are you?!
 371             Bye
 372        "}
 373    );
 374
 375    // With time-based splitting, there are two distinct events.
 376    let events = ep_store.update(cx, |ep_store, cx| {
 377        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 378    });
 379    assert_eq!(events.len(), 2);
 380    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 381    assert_eq!(
 382        diff.as_str(),
 383        indoc! {"
 384            @@ -1,3 +1,3 @@
 385             Hello!
 386            -
 387            +How
 388             Bye
 389        "}
 390    );
 391
 392    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 393    assert_eq!(
 394        diff.as_str(),
 395        indoc! {"
 396            @@ -1,3 +1,3 @@
 397             Hello!
 398            -How
 399            +How are you?!
 400             Bye
 401        "}
 402    );
 403}
 404
 405#[gpui::test]
 406async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
 407    let (ep_store, _requests) = init_test_with_fake_client(cx);
 408    let fs = FakeFs::new(cx.executor());
 409
 410    // Create a file with 30 lines to test line-based coalescing
 411    let content = (1..=30)
 412        .map(|i| format!("Line {}\n", i))
 413        .collect::<String>();
 414    fs.insert_tree(
 415        "/root",
 416        json!({
 417            "foo.md": content
 418        }),
 419    )
 420    .await;
 421    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 422
 423    let buffer = project
 424        .update(cx, |project, cx| {
 425            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 426            project.open_buffer(path, cx)
 427        })
 428        .await
 429        .unwrap();
 430
 431    ep_store.update(cx, |ep_store, cx| {
 432        ep_store.register_buffer(&buffer, &project, cx);
 433    });
 434
 435    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
 436    buffer.update(cx, |buffer, cx| {
 437        let start = Point::new(10, 0).to_offset(buffer);
 438        let end = Point::new(13, 0).to_offset(buffer);
 439        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
 440    });
 441
 442    let events = ep_store.update(cx, |ep_store, cx| {
 443        ep_store.edit_history_for_project(&project, cx)
 444    });
 445    assert_eq!(
 446        render_events(&events),
 447        indoc! {"
 448            @@ -8,9 +8,8 @@
 449             Line 8
 450             Line 9
 451             Line 10
 452            -Line 11
 453            -Line 12
 454            -Line 13
 455            +Middle A
 456            +Middle B
 457             Line 14
 458             Line 15
 459             Line 16
 460        "},
 461        "After first edit"
 462    );
 463
 464    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
 465    // This tests that coalescing considers the START of the existing range
 466    buffer.update(cx, |buffer, cx| {
 467        let offset = Point::new(5, 0).to_offset(buffer);
 468        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
 469    });
 470
 471    let events = ep_store.update(cx, |ep_store, cx| {
 472        ep_store.edit_history_for_project(&project, cx)
 473    });
 474    assert_eq!(
 475        render_events(&events),
 476        indoc! {"
 477            @@ -3,14 +3,14 @@
 478             Line 3
 479             Line 4
 480             Line 5
 481            +Above
 482             Line 6
 483             Line 7
 484             Line 8
 485             Line 9
 486             Line 10
 487            -Line 11
 488            -Line 12
 489            -Line 13
 490            +Middle A
 491            +Middle B
 492             Line 14
 493             Line 15
 494             Line 16
 495        "},
 496        "After inserting above (should coalesce)"
 497    );
 498
 499    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
 500    // This tests that coalescing considers the END of the existing range
 501    buffer.update(cx, |buffer, cx| {
 502        let offset = Point::new(14, 0).to_offset(buffer);
 503        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
 504    });
 505
 506    let events = ep_store.update(cx, |ep_store, cx| {
 507        ep_store.edit_history_for_project(&project, cx)
 508    });
 509    assert_eq!(
 510        render_events(&events),
 511        indoc! {"
 512            @@ -3,15 +3,16 @@
 513             Line 3
 514             Line 4
 515             Line 5
 516            +Above
 517             Line 6
 518             Line 7
 519             Line 8
 520             Line 9
 521             Line 10
 522            -Line 11
 523            -Line 12
 524            -Line 13
 525            +Middle A
 526            +Middle B
 527             Line 14
 528            +Below
 529             Line 15
 530             Line 16
 531             Line 17
 532        "},
 533        "After inserting below (should coalesce)"
 534    );
 535
 536    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
 537    // This should NOT coalesce - creates a new event
 538    buffer.update(cx, |buffer, cx| {
 539        let offset = Point::new(25, 0).to_offset(buffer);
 540        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
 541    });
 542
 543    let events = ep_store.update(cx, |ep_store, cx| {
 544        ep_store.edit_history_for_project(&project, cx)
 545    });
 546    assert_eq!(
 547        render_events(&events),
 548        indoc! {"
 549            @@ -3,15 +3,16 @@
 550             Line 3
 551             Line 4
 552             Line 5
 553            +Above
 554             Line 6
 555             Line 7
 556             Line 8
 557             Line 9
 558             Line 10
 559            -Line 11
 560            -Line 12
 561            -Line 13
 562            +Middle A
 563            +Middle B
 564             Line 14
 565            +Below
 566             Line 15
 567             Line 16
 568             Line 17
 569
 570            ---
 571            @@ -23,6 +23,7 @@
 572             Line 22
 573             Line 23
 574             Line 24
 575            +Far below
 576             Line 25
 577             Line 26
 578             Line 27
 579        "},
 580        "After inserting far below (should NOT coalesce)"
 581    );
 582}
 583
 584fn render_events(events: &[StoredEvent]) -> String {
 585    events
 586        .iter()
 587        .map(|e| {
 588            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 589            diff.as_str()
 590        })
 591        .collect::<Vec<_>>()
 592        .join("\n---\n")
 593}
 594
 595#[gpui::test]
 596async fn test_empty_prediction(cx: &mut TestAppContext) {
 597    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 598    let fs = FakeFs::new(cx.executor());
 599    fs.insert_tree(
 600        "/root",
 601        json!({
 602            "foo.md":  "Hello!\nHow\nBye\n"
 603        }),
 604    )
 605    .await;
 606    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 607
 608    let buffer = project
 609        .update(cx, |project, cx| {
 610            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 611            project.open_buffer(path, cx)
 612        })
 613        .await
 614        .unwrap();
 615    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 616    let position = snapshot.anchor_before(language::Point::new(1, 3));
 617
 618    ep_store.update(cx, |ep_store, cx| {
 619        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 620    });
 621
 622    let (request, respond_tx) = requests.predict.next().await.unwrap();
 623    let response = model_response(request, "");
 624    let id = response.id.clone();
 625    respond_tx.send(response).unwrap();
 626
 627    cx.run_until_parked();
 628
 629    ep_store.update(cx, |ep_store, cx| {
 630        assert!(
 631            ep_store
 632                .prediction_at(&buffer, None, &project, cx)
 633                .is_none()
 634        );
 635    });
 636
 637    // prediction is reported as rejected
 638    let (reject_request, _) = requests.reject.next().await.unwrap();
 639
 640    assert_eq!(
 641        &reject_request.rejections,
 642        &[EditPredictionRejection {
 643            request_id: id,
 644            reason: EditPredictionRejectReason::Empty,
 645            was_shown: false
 646        }]
 647    );
 648}
 649
 650#[gpui::test]
 651async fn test_interpolated_empty(cx: &mut TestAppContext) {
 652    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 653    let fs = FakeFs::new(cx.executor());
 654    fs.insert_tree(
 655        "/root",
 656        json!({
 657            "foo.md":  "Hello!\nHow\nBye\n"
 658        }),
 659    )
 660    .await;
 661    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 662
 663    let buffer = project
 664        .update(cx, |project, cx| {
 665            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 666            project.open_buffer(path, cx)
 667        })
 668        .await
 669        .unwrap();
 670    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 671    let position = snapshot.anchor_before(language::Point::new(1, 3));
 672
 673    ep_store.update(cx, |ep_store, cx| {
 674        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 675    });
 676
 677    let (request, respond_tx) = requests.predict.next().await.unwrap();
 678
 679    buffer.update(cx, |buffer, cx| {
 680        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 681    });
 682
 683    let response = model_response(request, SIMPLE_DIFF);
 684    let id = response.id.clone();
 685    respond_tx.send(response).unwrap();
 686
 687    cx.run_until_parked();
 688
 689    ep_store.update(cx, |ep_store, cx| {
 690        assert!(
 691            ep_store
 692                .prediction_at(&buffer, None, &project, cx)
 693                .is_none()
 694        );
 695    });
 696
 697    // prediction is reported as rejected
 698    let (reject_request, _) = requests.reject.next().await.unwrap();
 699
 700    assert_eq!(
 701        &reject_request.rejections,
 702        &[EditPredictionRejection {
 703            request_id: id,
 704            reason: EditPredictionRejectReason::InterpolatedEmpty,
 705            was_shown: false
 706        }]
 707    );
 708}
 709
 710const SIMPLE_DIFF: &str = indoc! { r"
 711    --- a/root/foo.md
 712    +++ b/root/foo.md
 713    @@ ... @@
 714     Hello!
 715    -How
 716    +How are you?
 717     Bye
 718"};
 719
 720#[gpui::test]
 721async fn test_replace_current(cx: &mut TestAppContext) {
 722    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 723    let fs = FakeFs::new(cx.executor());
 724    fs.insert_tree(
 725        "/root",
 726        json!({
 727            "foo.md":  "Hello!\nHow\nBye\n"
 728        }),
 729    )
 730    .await;
 731    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 732
 733    let buffer = project
 734        .update(cx, |project, cx| {
 735            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 736            project.open_buffer(path, cx)
 737        })
 738        .await
 739        .unwrap();
 740    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 741    let position = snapshot.anchor_before(language::Point::new(1, 3));
 742
 743    ep_store.update(cx, |ep_store, cx| {
 744        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 745    });
 746
 747    let (request, respond_tx) = requests.predict.next().await.unwrap();
 748    let first_response = model_response(request, SIMPLE_DIFF);
 749    let first_id = first_response.id.clone();
 750    respond_tx.send(first_response).unwrap();
 751
 752    cx.run_until_parked();
 753
 754    ep_store.update(cx, |ep_store, cx| {
 755        assert_eq!(
 756            ep_store
 757                .prediction_at(&buffer, None, &project, cx)
 758                .unwrap()
 759                .id
 760                .0,
 761            first_id
 762        );
 763    });
 764
 765    // a second request is triggered
 766    ep_store.update(cx, |ep_store, cx| {
 767        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 768    });
 769
 770    let (request, respond_tx) = requests.predict.next().await.unwrap();
 771    let second_response = model_response(request, SIMPLE_DIFF);
 772    let second_id = second_response.id.clone();
 773    respond_tx.send(second_response).unwrap();
 774
 775    cx.run_until_parked();
 776
 777    ep_store.update(cx, |ep_store, cx| {
 778        // second replaces first
 779        assert_eq!(
 780            ep_store
 781                .prediction_at(&buffer, None, &project, cx)
 782                .unwrap()
 783                .id
 784                .0,
 785            second_id
 786        );
 787    });
 788
 789    // first is reported as replaced
 790    let (reject_request, _) = requests.reject.next().await.unwrap();
 791
 792    assert_eq!(
 793        &reject_request.rejections,
 794        &[EditPredictionRejection {
 795            request_id: first_id,
 796            reason: EditPredictionRejectReason::Replaced,
 797            was_shown: false
 798        }]
 799    );
 800}
 801
 802#[gpui::test]
 803async fn test_current_preferred(cx: &mut TestAppContext) {
 804    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 805    let fs = FakeFs::new(cx.executor());
 806    fs.insert_tree(
 807        "/root",
 808        json!({
 809            "foo.md":  "Hello!\nHow\nBye\n"
 810        }),
 811    )
 812    .await;
 813    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 814
 815    let buffer = project
 816        .update(cx, |project, cx| {
 817            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 818            project.open_buffer(path, cx)
 819        })
 820        .await
 821        .unwrap();
 822    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 823    let position = snapshot.anchor_before(language::Point::new(1, 3));
 824
 825    ep_store.update(cx, |ep_store, cx| {
 826        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 827    });
 828
 829    let (request, respond_tx) = requests.predict.next().await.unwrap();
 830    let first_response = model_response(request, SIMPLE_DIFF);
 831    let first_id = first_response.id.clone();
 832    respond_tx.send(first_response).unwrap();
 833
 834    cx.run_until_parked();
 835
 836    ep_store.update(cx, |ep_store, cx| {
 837        assert_eq!(
 838            ep_store
 839                .prediction_at(&buffer, None, &project, cx)
 840                .unwrap()
 841                .id
 842                .0,
 843            first_id
 844        );
 845    });
 846
 847    // a second request is triggered
 848    ep_store.update(cx, |ep_store, cx| {
 849        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 850    });
 851
 852    let (request, respond_tx) = requests.predict.next().await.unwrap();
 853    // worse than current prediction
 854    let second_response = model_response(
 855        request,
 856        indoc! { r"
 857            --- a/root/foo.md
 858            +++ b/root/foo.md
 859            @@ ... @@
 860             Hello!
 861            -How
 862            +How are
 863             Bye
 864        "},
 865    );
 866    let second_id = second_response.id.clone();
 867    respond_tx.send(second_response).unwrap();
 868
 869    cx.run_until_parked();
 870
 871    ep_store.update(cx, |ep_store, cx| {
 872        // first is preferred over second
 873        assert_eq!(
 874            ep_store
 875                .prediction_at(&buffer, None, &project, cx)
 876                .unwrap()
 877                .id
 878                .0,
 879            first_id
 880        );
 881    });
 882
 883    // second is reported as rejected
 884    let (reject_request, _) = requests.reject.next().await.unwrap();
 885
 886    assert_eq!(
 887        &reject_request.rejections,
 888        &[EditPredictionRejection {
 889            request_id: second_id,
 890            reason: EditPredictionRejectReason::CurrentPreferred,
 891            was_shown: false
 892        }]
 893    );
 894}
 895
 896#[gpui::test]
 897async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 898    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 899    let fs = FakeFs::new(cx.executor());
 900    fs.insert_tree(
 901        "/root",
 902        json!({
 903            "foo.md":  "Hello!\nHow\nBye\n"
 904        }),
 905    )
 906    .await;
 907    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 908
 909    let buffer = project
 910        .update(cx, |project, cx| {
 911            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 912            project.open_buffer(path, cx)
 913        })
 914        .await
 915        .unwrap();
 916    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 917    let position = snapshot.anchor_before(language::Point::new(1, 3));
 918
 919    // start two refresh tasks
 920    ep_store.update(cx, |ep_store, cx| {
 921        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 922    });
 923
 924    let (request1, respond_first) = requests.predict.next().await.unwrap();
 925
 926    ep_store.update(cx, |ep_store, cx| {
 927        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 928    });
 929
 930    let (request, respond_second) = requests.predict.next().await.unwrap();
 931
 932    // wait for throttle
 933    cx.run_until_parked();
 934
 935    // second responds first
 936    let second_response = model_response(request, SIMPLE_DIFF);
 937    let second_id = second_response.id.clone();
 938    respond_second.send(second_response).unwrap();
 939
 940    cx.run_until_parked();
 941
 942    ep_store.update(cx, |ep_store, cx| {
 943        // current prediction is second
 944        assert_eq!(
 945            ep_store
 946                .prediction_at(&buffer, None, &project, cx)
 947                .unwrap()
 948                .id
 949                .0,
 950            second_id
 951        );
 952    });
 953
 954    let first_response = model_response(request1, SIMPLE_DIFF);
 955    let first_id = first_response.id.clone();
 956    respond_first.send(first_response).unwrap();
 957
 958    cx.run_until_parked();
 959
 960    ep_store.update(cx, |ep_store, cx| {
 961        // current prediction is still second, since first was cancelled
 962        assert_eq!(
 963            ep_store
 964                .prediction_at(&buffer, None, &project, cx)
 965                .unwrap()
 966                .id
 967                .0,
 968            second_id
 969        );
 970    });
 971
 972    // first is reported as rejected
 973    let (reject_request, _) = requests.reject.next().await.unwrap();
 974
 975    cx.run_until_parked();
 976
 977    assert_eq!(
 978        &reject_request.rejections,
 979        &[EditPredictionRejection {
 980            request_id: first_id,
 981            reason: EditPredictionRejectReason::Canceled,
 982            was_shown: false
 983        }]
 984    );
 985}
 986
 987#[gpui::test]
 988async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
 989    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 990    let fs = FakeFs::new(cx.executor());
 991    fs.insert_tree(
 992        "/root",
 993        json!({
 994            "foo.md":  "Hello!\nHow\nBye\n"
 995        }),
 996    )
 997    .await;
 998    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 999
1000    let buffer = project
1001        .update(cx, |project, cx| {
1002            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1003            project.open_buffer(path, cx)
1004        })
1005        .await
1006        .unwrap();
1007    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1008    let position = snapshot.anchor_before(language::Point::new(1, 3));
1009
1010    // start two refresh tasks
1011    ep_store.update(cx, |ep_store, cx| {
1012        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1013    });
1014
1015    let (request1, respond_first) = requests.predict.next().await.unwrap();
1016
1017    ep_store.update(cx, |ep_store, cx| {
1018        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1019    });
1020
1021    let (request2, respond_second) = requests.predict.next().await.unwrap();
1022
1023    // wait for throttle, so requests are sent
1024    cx.run_until_parked();
1025
1026    ep_store.update(cx, |ep_store, cx| {
1027        // start a third request
1028        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1029
1030        // 2 are pending, so 2nd is cancelled
1031        assert_eq!(
1032            ep_store
1033                .get_or_init_project(&project, cx)
1034                .cancelled_predictions
1035                .iter()
1036                .copied()
1037                .collect::<Vec<_>>(),
1038            [1]
1039        );
1040    });
1041
1042    // wait for throttle
1043    cx.run_until_parked();
1044
1045    let (request3, respond_third) = requests.predict.next().await.unwrap();
1046
1047    let first_response = model_response(request1, SIMPLE_DIFF);
1048    let first_id = first_response.id.clone();
1049    respond_first.send(first_response).unwrap();
1050
1051    cx.run_until_parked();
1052
1053    ep_store.update(cx, |ep_store, cx| {
1054        // current prediction is first
1055        assert_eq!(
1056            ep_store
1057                .prediction_at(&buffer, None, &project, cx)
1058                .unwrap()
1059                .id
1060                .0,
1061            first_id
1062        );
1063    });
1064
1065    let cancelled_response = model_response(request2, SIMPLE_DIFF);
1066    let cancelled_id = cancelled_response.id.clone();
1067    respond_second.send(cancelled_response).unwrap();
1068
1069    cx.run_until_parked();
1070
1071    ep_store.update(cx, |ep_store, cx| {
1072        // current prediction is still first, since second was cancelled
1073        assert_eq!(
1074            ep_store
1075                .prediction_at(&buffer, None, &project, cx)
1076                .unwrap()
1077                .id
1078                .0,
1079            first_id
1080        );
1081    });
1082
1083    let third_response = model_response(request3, SIMPLE_DIFF);
1084    let third_response_id = third_response.id.clone();
1085    respond_third.send(third_response).unwrap();
1086
1087    cx.run_until_parked();
1088
1089    ep_store.update(cx, |ep_store, cx| {
1090        // third completes and replaces first
1091        assert_eq!(
1092            ep_store
1093                .prediction_at(&buffer, None, &project, cx)
1094                .unwrap()
1095                .id
1096                .0,
1097            third_response_id
1098        );
1099    });
1100
1101    // second is reported as rejected
1102    let (reject_request, _) = requests.reject.next().await.unwrap();
1103
1104    cx.run_until_parked();
1105
1106    assert_eq!(
1107        &reject_request.rejections,
1108        &[
1109            EditPredictionRejection {
1110                request_id: cancelled_id,
1111                reason: EditPredictionRejectReason::Canceled,
1112                was_shown: false
1113            },
1114            EditPredictionRejection {
1115                request_id: first_id,
1116                reason: EditPredictionRejectReason::Replaced,
1117                was_shown: false
1118            }
1119        ]
1120    );
1121}
1122
1123#[gpui::test]
1124async fn test_rejections_flushing(cx: &mut TestAppContext) {
1125    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1126
1127    ep_store.update(cx, |ep_store, _cx| {
1128        ep_store.reject_prediction(
1129            EditPredictionId("test-1".into()),
1130            EditPredictionRejectReason::Discarded,
1131            false,
1132        );
1133        ep_store.reject_prediction(
1134            EditPredictionId("test-2".into()),
1135            EditPredictionRejectReason::Canceled,
1136            true,
1137        );
1138    });
1139
1140    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1141    cx.run_until_parked();
1142
1143    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1144    respond_tx.send(()).unwrap();
1145
1146    // batched
1147    assert_eq!(reject_request.rejections.len(), 2);
1148    assert_eq!(
1149        reject_request.rejections[0],
1150        EditPredictionRejection {
1151            request_id: "test-1".to_string(),
1152            reason: EditPredictionRejectReason::Discarded,
1153            was_shown: false
1154        }
1155    );
1156    assert_eq!(
1157        reject_request.rejections[1],
1158        EditPredictionRejection {
1159            request_id: "test-2".to_string(),
1160            reason: EditPredictionRejectReason::Canceled,
1161            was_shown: true
1162        }
1163    );
1164
1165    // Reaching batch size limit sends without debounce
1166    ep_store.update(cx, |ep_store, _cx| {
1167        for i in 0..70 {
1168            ep_store.reject_prediction(
1169                EditPredictionId(format!("batch-{}", i).into()),
1170                EditPredictionRejectReason::Discarded,
1171                false,
1172            );
1173        }
1174    });
1175
1176    // First MAX/2 items are sent immediately
1177    cx.run_until_parked();
1178    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1179    respond_tx.send(()).unwrap();
1180
1181    assert_eq!(reject_request.rejections.len(), 50);
1182    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1183    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1184
1185    // Remaining items are debounced with the next batch
1186    cx.executor().advance_clock(Duration::from_secs(15));
1187    cx.run_until_parked();
1188
1189    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1190    respond_tx.send(()).unwrap();
1191
1192    assert_eq!(reject_request.rejections.len(), 20);
1193    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1194    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1195
1196    // Request failure
1197    ep_store.update(cx, |ep_store, _cx| {
1198        ep_store.reject_prediction(
1199            EditPredictionId("retry-1".into()),
1200            EditPredictionRejectReason::Discarded,
1201            false,
1202        );
1203    });
1204
1205    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1206    cx.run_until_parked();
1207
1208    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1209    assert_eq!(reject_request.rejections.len(), 1);
1210    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1211    // Simulate failure
1212    drop(_respond_tx);
1213
1214    // Add another rejection
1215    ep_store.update(cx, |ep_store, _cx| {
1216        ep_store.reject_prediction(
1217            EditPredictionId("retry-2".into()),
1218            EditPredictionRejectReason::Discarded,
1219            false,
1220        );
1221    });
1222
1223    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1224    cx.run_until_parked();
1225
1226    // Retry should include both the failed item and the new one
1227    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1228    respond_tx.send(()).unwrap();
1229
1230    assert_eq!(reject_request.rejections.len(), 2);
1231    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1232    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1233}
1234
1235// Skipped until we start including diagnostics in prompt
1236// #[gpui::test]
1237// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1238//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1239//     let fs = FakeFs::new(cx.executor());
1240//     fs.insert_tree(
1241//         "/root",
1242//         json!({
1243//             "foo.md": "Hello!\nBye"
1244//         }),
1245//     )
1246//     .await;
1247//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1248
1249//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1250//     let diagnostic = lsp::Diagnostic {
1251//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1252//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1253//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1254//         ..Default::default()
1255//     };
1256
1257//     project.update(cx, |project, cx| {
1258//         project.lsp_store().update(cx, |lsp_store, cx| {
1259//             // Create some diagnostics
1260//             lsp_store
1261//                 .update_diagnostics(
1262//                     LanguageServerId(0),
1263//                     lsp::PublishDiagnosticsParams {
1264//                         uri: path_to_buffer_uri.clone(),
1265//                         diagnostics: vec![diagnostic],
1266//                         version: None,
1267//                     },
1268//                     None,
1269//                     language::DiagnosticSourceKind::Pushed,
1270//                     &[],
1271//                     cx,
1272//                 )
1273//                 .unwrap();
1274//         });
1275//     });
1276
1277//     let buffer = project
1278//         .update(cx, |project, cx| {
1279//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1280//             project.open_buffer(path, cx)
1281//         })
1282//         .await
1283//         .unwrap();
1284
1285//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1286//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1287
1288//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1289//         ep_store.request_prediction(&project, &buffer, position, cx)
1290//     });
1291
1292//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1293
1294//     assert_eq!(request.diagnostic_groups.len(), 1);
1295//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1296//         .unwrap();
1297//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1298//     assert_eq!(
1299//         value,
1300//         json!({
1301//             "entries": [{
1302//                 "range": {
1303//                     "start": 8,
1304//                     "end": 10
1305//                 },
1306//                 "diagnostic": {
1307//                     "source": null,
1308//                     "code": null,
1309//                     "code_description": null,
1310//                     "severity": 1,
1311//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1312//                     "markdown": null,
1313//                     "group_id": 0,
1314//                     "is_primary": true,
1315//                     "is_disk_based": false,
1316//                     "is_unnecessary": false,
1317//                     "source_kind": "Pushed",
1318//                     "data": null,
1319//                     "underline": true
1320//                 }
1321//             }],
1322//             "primary_ix": 0
1323//         })
1324//     );
1325// }
1326
1327// Generate a model response that would apply the given diff to the active file.
1328fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
1329    let prompt = match &request.messages[0] {
1330        open_ai::RequestMessage::User {
1331            content: open_ai::MessageContent::Plain(content),
1332        } => content,
1333        _ => panic!("unexpected request {request:?}"),
1334    };
1335
1336    let open = "<editable_region>\n";
1337    let close = "</editable_region>";
1338    let cursor = "<|user_cursor|>";
1339
1340    let start_ix = open.len() + prompt.find(open).unwrap();
1341    let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
1342    let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1343    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1344
1345    open_ai::Response {
1346        id: Uuid::new_v4().to_string(),
1347        object: "response".into(),
1348        created: 0,
1349        model: "model".into(),
1350        choices: vec![open_ai::Choice {
1351            index: 0,
1352            message: open_ai::RequestMessage::Assistant {
1353                content: Some(open_ai::MessageContent::Plain(new_excerpt)),
1354                tool_calls: vec![],
1355            },
1356            finish_reason: None,
1357        }],
1358        usage: Usage {
1359            prompt_tokens: 0,
1360            completion_tokens: 0,
1361            total_tokens: 0,
1362        },
1363    }
1364}
1365
1366fn prompt_from_request(request: &open_ai::Request) -> &str {
1367    assert_eq!(request.messages.len(), 1);
1368    let open_ai::RequestMessage::User {
1369        content: open_ai::MessageContent::Plain(content),
1370        ..
1371    } = &request.messages[0]
1372    else {
1373        panic!(
1374            "Request does not have single user message of type Plain. {:#?}",
1375            request
1376        );
1377    };
1378    content
1379}
1380
/// Receiving ends of the channels on which the fake HTTP client built by
/// `init_test_with_fake_client` hands intercepted requests to a test.
///
/// Each item pairs the deserialized request body with a oneshot sender; the HTTP
/// handler waits on that sender's value before it responds.
struct RequestChannels {
    /// Requests made to `/predict_edits/raw`.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    /// Requests made to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1385
1386fn init_test_with_fake_client(
1387    cx: &mut TestAppContext,
1388) -> (Entity<EditPredictionStore>, RequestChannels) {
1389    cx.update(move |cx| {
1390        let settings_store = SettingsStore::test(cx);
1391        cx.set_global(settings_store);
1392        zlog::init_test();
1393
1394        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
1395        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
1396
1397        let http_client = FakeHttpClient::create({
1398            move |req| {
1399                let uri = req.uri().path().to_string();
1400                let mut body = req.into_body();
1401                let predict_req_tx = predict_req_tx.clone();
1402                let reject_req_tx = reject_req_tx.clone();
1403                async move {
1404                    let resp = match uri.as_str() {
1405                        "/client/llm_tokens" => serde_json::to_string(&json!({
1406                            "token": "test"
1407                        }))
1408                        .unwrap(),
1409                        "/predict_edits/raw" => {
1410                            let mut buf = Vec::new();
1411                            body.read_to_end(&mut buf).await.ok();
1412                            let req = serde_json::from_slice(&buf).unwrap();
1413
1414                            let (res_tx, res_rx) = oneshot::channel();
1415                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
1416                            serde_json::to_string(&res_rx.await?).unwrap()
1417                        }
1418                        "/predict_edits/reject" => {
1419                            let mut buf = Vec::new();
1420                            body.read_to_end(&mut buf).await.ok();
1421                            let req = serde_json::from_slice(&buf).unwrap();
1422
1423                            let (res_tx, res_rx) = oneshot::channel();
1424                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
1425                            serde_json::to_string(&res_rx.await?).unwrap()
1426                        }
1427                        _ => {
1428                            panic!("Unexpected path: {}", uri)
1429                        }
1430                    };
1431
1432                    Ok(Response::builder().body(resp.into()).unwrap())
1433                }
1434            }
1435        });
1436
1437        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1438        client.cloud_client().set_credentials(1, "test".into());
1439
1440        language_model::init(client.clone(), cx);
1441
1442        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1443        let ep_store = EditPredictionStore::global(&client, &user_store, cx);
1444
1445        (
1446            ep_store,
1447            RequestChannels {
1448                predict: predict_req_rx,
1449                reject: reject_req_rx,
1450            },
1451        )
1452    })
1453}
1454
/// Contents of the bundled `0bsd.txt` license example, embedded at compile time.
/// The data-collection tests write it as the `LICENSE` file of worktrees that are
/// expected to allow data collection.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1456
1457#[gpui::test]
1458async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1459    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1460    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
1461        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
1462    });
1463
1464    let edit_preview = cx
1465        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1466        .await;
1467
1468    let prediction = EditPrediction {
1469        edits,
1470        edit_preview,
1471        buffer: buffer.clone(),
1472        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1473        id: EditPredictionId("the-id".into()),
1474        inputs: ZetaPromptInput {
1475            events: Default::default(),
1476            related_files: Default::default(),
1477            cursor_path: Path::new("").into(),
1478            cursor_excerpt: "".into(),
1479            editable_range_in_excerpt: 0..0,
1480            cursor_offset_in_excerpt: 0,
1481        },
1482        buffer_snapshotted_at: Instant::now(),
1483        response_received_at: Instant::now(),
1484    };
1485
1486    cx.update(|cx| {
1487        assert_eq!(
1488            from_completion_edits(
1489                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1490                &buffer,
1491                cx
1492            ),
1493            vec![(2..5, "REM".into()), (9..11, "".into())]
1494        );
1495
1496        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1497        assert_eq!(
1498            from_completion_edits(
1499                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1500                &buffer,
1501                cx
1502            ),
1503            vec![(2..2, "REM".into()), (6..8, "".into())]
1504        );
1505
1506        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1507        assert_eq!(
1508            from_completion_edits(
1509                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1510                &buffer,
1511                cx
1512            ),
1513            vec![(2..5, "REM".into()), (9..11, "".into())]
1514        );
1515
1516        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1517        assert_eq!(
1518            from_completion_edits(
1519                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1520                &buffer,
1521                cx
1522            ),
1523            vec![(3..3, "EM".into()), (7..9, "".into())]
1524        );
1525
1526        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1527        assert_eq!(
1528            from_completion_edits(
1529                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1530                &buffer,
1531                cx
1532            ),
1533            vec![(4..4, "M".into()), (8..10, "".into())]
1534        );
1535
1536        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1537        assert_eq!(
1538            from_completion_edits(
1539                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1540                &buffer,
1541                cx
1542            ),
1543            vec![(9..11, "".into())]
1544        );
1545
1546        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1547        assert_eq!(
1548            from_completion_edits(
1549                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1550                &buffer,
1551                cx
1552            ),
1553            vec![(4..4, "M".into()), (8..10, "".into())]
1554        );
1555
1556        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1557        assert_eq!(
1558            from_completion_edits(
1559                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1560                &buffer,
1561                cx
1562            ),
1563            vec![(4..4, "M".into())]
1564        );
1565
1566        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1567        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1568    })
1569}
1570
1571#[gpui::test]
1572async fn test_clean_up_diff(cx: &mut TestAppContext) {
1573    init_test(cx);
1574
1575    assert_eq!(
1576        apply_edit_prediction(
1577            indoc! {"
1578                    fn main() {
1579                        let word_1 = \"lorem\";
1580                        let range = word.len()..word.len();
1581                    }
1582                "},
1583            indoc! {"
1584                    <|editable_region_start|>
1585                    fn main() {
1586                        let word_1 = \"lorem\";
1587                        let range = word_1.len()..word_1.len();
1588                    }
1589
1590                    <|editable_region_end|>
1591                "},
1592            cx,
1593        )
1594        .await,
1595        indoc! {"
1596                fn main() {
1597                    let word_1 = \"lorem\";
1598                    let range = word_1.len()..word_1.len();
1599                }
1600            "},
1601    );
1602
1603    assert_eq!(
1604        apply_edit_prediction(
1605            indoc! {"
1606                    fn main() {
1607                        let story = \"the quick\"
1608                    }
1609                "},
1610            indoc! {"
1611                    <|editable_region_start|>
1612                    fn main() {
1613                        let story = \"the quick brown fox jumps over the lazy dog\";
1614                    }
1615
1616                    <|editable_region_end|>
1617                "},
1618            cx,
1619        )
1620        .await,
1621        indoc! {"
1622                fn main() {
1623                    let story = \"the quick brown fox jumps over the lazy dog\";
1624                }
1625            "},
1626    );
1627}
1628
1629#[gpui::test]
1630async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1631    init_test(cx);
1632
1633    let buffer_content = "lorem\n";
1634    let completion_response = indoc! {"
1635            ```animals.js
1636            <|start_of_file|>
1637            <|editable_region_start|>
1638            lorem
1639            ipsum
1640            <|editable_region_end|>
1641            ```"};
1642
1643    assert_eq!(
1644        apply_edit_prediction(buffer_content, completion_response, cx).await,
1645        "lorem\nipsum"
1646    );
1647}
1648
1649#[gpui::test]
1650async fn test_can_collect_data(cx: &mut TestAppContext) {
1651    init_test(cx);
1652
1653    let fs = project::FakeFs::new(cx.executor());
1654    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1655        .await;
1656
1657    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1658    let buffer = project
1659        .update(cx, |project, cx| {
1660            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1661        })
1662        .await
1663        .unwrap();
1664
1665    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1666    ep_store.update(cx, |ep_store, _cx| {
1667        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1668    });
1669
1670    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1671    assert_eq!(
1672        captured_request.lock().clone().unwrap().can_collect_data,
1673        true
1674    );
1675
1676    ep_store.update(cx, |ep_store, _cx| {
1677        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1678    });
1679
1680    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1681    assert_eq!(
1682        captured_request.lock().clone().unwrap().can_collect_data,
1683        false
1684    );
1685}
1686
1687#[gpui::test]
1688async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1689    init_test(cx);
1690
1691    let fs = project::FakeFs::new(cx.executor());
1692    let project = Project::test(fs.clone(), [], cx).await;
1693
1694    let buffer = cx.new(|_cx| {
1695        Buffer::remote(
1696            language::BufferId::new(1).unwrap(),
1697            ReplicaId::new(1),
1698            language::Capability::ReadWrite,
1699            "fn main() {\n    println!(\"Hello\");\n}",
1700        )
1701    });
1702
1703    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1704    ep_store.update(cx, |ep_store, _cx| {
1705        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1706    });
1707
1708    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1709    assert_eq!(
1710        captured_request.lock().clone().unwrap().can_collect_data,
1711        false
1712    );
1713}
1714
1715#[gpui::test]
1716async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1717    init_test(cx);
1718
1719    let fs = project::FakeFs::new(cx.executor());
1720    fs.insert_tree(
1721        path!("/project"),
1722        json!({
1723            "LICENSE": BSD_0_TXT,
1724            ".env": "SECRET_KEY=secret"
1725        }),
1726    )
1727    .await;
1728
1729    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1730    let buffer = project
1731        .update(cx, |project, cx| {
1732            project.open_local_buffer("/project/.env", cx)
1733        })
1734        .await
1735        .unwrap();
1736
1737    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1738    ep_store.update(cx, |ep_store, _cx| {
1739        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1740    });
1741
1742    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1743    assert_eq!(
1744        captured_request.lock().clone().unwrap().can_collect_data,
1745        false
1746    );
1747}
1748
1749#[gpui::test]
1750async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1751    init_test(cx);
1752
1753    let fs = project::FakeFs::new(cx.executor());
1754    let project = Project::test(fs.clone(), [], cx).await;
1755    let buffer = cx.new(|cx| Buffer::local("", cx));
1756
1757    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1758    ep_store.update(cx, |ep_store, _cx| {
1759        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1760    });
1761
1762    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1763    assert_eq!(
1764        captured_request.lock().clone().unwrap().can_collect_data,
1765        false
1766    );
1767}
1768
1769#[gpui::test]
1770async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1771    init_test(cx);
1772
1773    let fs = project::FakeFs::new(cx.executor());
1774    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1775        .await;
1776
1777    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1778    let buffer = project
1779        .update(cx, |project, cx| {
1780            project.open_local_buffer("/project/main.rs", cx)
1781        })
1782        .await
1783        .unwrap();
1784
1785    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1786    ep_store.update(cx, |ep_store, _cx| {
1787        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1788    });
1789
1790    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1791    assert_eq!(
1792        captured_request.lock().clone().unwrap().can_collect_data,
1793        false
1794    );
1795}
1796
1797#[gpui::test]
1798async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1799    init_test(cx);
1800
1801    let fs = project::FakeFs::new(cx.executor());
1802    fs.insert_tree(
1803        path!("/open_source_worktree"),
1804        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1805    )
1806    .await;
1807    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1808        .await;
1809
1810    let project = Project::test(
1811        fs.clone(),
1812        [
1813            path!("/open_source_worktree").as_ref(),
1814            path!("/closed_source_worktree").as_ref(),
1815        ],
1816        cx,
1817    )
1818    .await;
1819    let buffer = project
1820        .update(cx, |project, cx| {
1821            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1822        })
1823        .await
1824        .unwrap();
1825
1826    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1827    ep_store.update(cx, |ep_store, _cx| {
1828        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1829    });
1830
1831    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1832    assert_eq!(
1833        captured_request.lock().clone().unwrap().can_collect_data,
1834        true
1835    );
1836
1837    let closed_source_file = project
1838        .update(cx, |project, cx| {
1839            let worktree2 = project
1840                .worktree_for_root_name("closed_source_worktree", cx)
1841                .unwrap();
1842            worktree2.update(cx, |worktree2, cx| {
1843                worktree2.load_file(rel_path("main.rs"), cx)
1844            })
1845        })
1846        .await
1847        .unwrap()
1848        .file;
1849
1850    buffer.update(cx, |buffer, cx| {
1851        buffer.file_updated(closed_source_file, cx);
1852    });
1853
1854    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1855    assert_eq!(
1856        captured_request.lock().clone().unwrap().can_collect_data,
1857        false
1858    );
1859}
1860
1861#[gpui::test]
1862async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1863    init_test(cx);
1864
1865    let fs = project::FakeFs::new(cx.executor());
1866    fs.insert_tree(
1867        path!("/worktree1"),
1868        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1869    )
1870    .await;
1871    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1872        .await;
1873
1874    let project = Project::test(
1875        fs.clone(),
1876        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1877        cx,
1878    )
1879    .await;
1880    let buffer = project
1881        .update(cx, |project, cx| {
1882            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1883        })
1884        .await
1885        .unwrap();
1886    let private_buffer = project
1887        .update(cx, |project, cx| {
1888            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1889        })
1890        .await
1891        .unwrap();
1892
1893    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1894    ep_store.update(cx, |ep_store, _cx| {
1895        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1896    });
1897
1898    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1899    assert_eq!(
1900        captured_request.lock().clone().unwrap().can_collect_data,
1901        true
1902    );
1903
1904    // this has a side effect of registering the buffer to watch for edits
1905    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1906    assert_eq!(
1907        captured_request.lock().clone().unwrap().can_collect_data,
1908        false
1909    );
1910
1911    private_buffer.update(cx, |private_buffer, cx| {
1912        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1913    });
1914
1915    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1916    assert_eq!(
1917        captured_request.lock().clone().unwrap().can_collect_data,
1918        false
1919    );
1920
1921    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1922    // included
1923    buffer.update(cx, |buffer, cx| {
1924        buffer.edit(
1925            [(
1926                0..0,
1927                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1928            )],
1929            None,
1930            cx,
1931        );
1932    });
1933
1934    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1935    assert_eq!(
1936        captured_request.lock().clone().unwrap().can_collect_data,
1937        true
1938    );
1939}
1940
1941fn init_test(cx: &mut TestAppContext) {
1942    cx.update(|cx| {
1943        let settings_store = SettingsStore::test(cx);
1944        cx.set_global(settings_store);
1945    });
1946}
1947
1948async fn apply_edit_prediction(
1949    buffer_content: &str,
1950    completion_response: &str,
1951    cx: &mut TestAppContext,
1952) -> String {
1953    let fs = project::FakeFs::new(cx.executor());
1954    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1955    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1956    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1957    *response.lock() = completion_response.to_string();
1958    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1959    buffer.update(cx, |buffer, cx| {
1960        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1961    });
1962    buffer.read_with(cx, |buffer, _| buffer.text())
1963}
1964
1965async fn run_edit_prediction(
1966    buffer: &Entity<Buffer>,
1967    project: &Entity<Project>,
1968    ep_store: &Entity<EditPredictionStore>,
1969    cx: &mut TestAppContext,
1970) -> EditPrediction {
1971    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1972    ep_store.update(cx, |ep_store, cx| {
1973        ep_store.register_buffer(buffer, &project, cx)
1974    });
1975    cx.background_executor.run_until_parked();
1976    let prediction_task = ep_store.update(cx, |ep_store, cx| {
1977        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1978    });
1979    prediction_task.await.unwrap().unwrap().prediction.unwrap()
1980}
1981
1982async fn make_test_ep_store(
1983    project: &Entity<Project>,
1984    cx: &mut TestAppContext,
1985) -> (
1986    Entity<EditPredictionStore>,
1987    Arc<Mutex<Option<PredictEditsBody>>>,
1988    Arc<Mutex<String>>,
1989) {
1990    let default_response = indoc! {"
1991            ```main.rs
1992            <|start_of_file|>
1993            <|editable_region_start|>
1994            hello world
1995            <|editable_region_end|>
1996            ```"
1997    };
1998    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1999    let completion_response: Arc<Mutex<String>> =
2000        Arc::new(Mutex::new(default_response.to_string()));
2001    let http_client = FakeHttpClient::create({
2002        let captured_request = captured_request.clone();
2003        let completion_response = completion_response.clone();
2004        let mut next_request_id = 0;
2005        move |req| {
2006            let captured_request = captured_request.clone();
2007            let completion_response = completion_response.clone();
2008            async move {
2009                match (req.method(), req.uri().path()) {
2010                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2011                        .status(200)
2012                        .body(
2013                            serde_json::to_string(&CreateLlmTokenResponse {
2014                                token: LlmToken("the-llm-token".to_string()),
2015                            })
2016                            .unwrap()
2017                            .into(),
2018                        )
2019                        .unwrap()),
2020                    (&Method::POST, "/predict_edits/v2") => {
2021                        let mut request_body = String::new();
2022                        req.into_body().read_to_string(&mut request_body).await?;
2023                        *captured_request.lock() =
2024                            Some(serde_json::from_str(&request_body).unwrap());
2025                        next_request_id += 1;
2026                        Ok(http_client::Response::builder()
2027                            .status(200)
2028                            .body(
2029                                serde_json::to_string(&PredictEditsResponse {
2030                                    request_id: format!("request-{next_request_id}"),
2031                                    output_excerpt: completion_response.lock().clone(),
2032                                })
2033                                .unwrap()
2034                                .into(),
2035                            )
2036                            .unwrap())
2037                    }
2038                    _ => Ok(http_client::Response::builder()
2039                        .status(404)
2040                        .body("Not Found".into())
2041                        .unwrap()),
2042                }
2043            }
2044        }
2045    });
2046
2047    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2048    cx.update(|cx| {
2049        RefreshLlmTokenListener::register(client.clone(), cx);
2050    });
2051    let _server = FakeServer::for_client(42, &client, cx).await;
2052
2053    let ep_store = cx.new(|cx| {
2054        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2055        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2056
2057        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2058        for worktree in worktrees {
2059            let worktree_id = worktree.read(cx).id();
2060            ep_store
2061                .get_or_init_project(project, cx)
2062                .license_detection_watchers
2063                .entry(worktree_id)
2064                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2065        }
2066
2067        ep_store
2068    });
2069
2070    (ep_store, captured_request, completion_response)
2071}
2072
2073fn to_completion_edits(
2074    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2075    buffer: &Entity<Buffer>,
2076    cx: &App,
2077) -> Vec<(Range<Anchor>, Arc<str>)> {
2078    let buffer = buffer.read(cx);
2079    iterator
2080        .into_iter()
2081        .map(|(range, text)| {
2082            (
2083                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2084                text,
2085            )
2086        })
2087        .collect()
2088}
2089
2090fn from_completion_edits(
2091    editor_edits: &[(Range<Anchor>, Arc<str>)],
2092    buffer: &Entity<Buffer>,
2093    cx: &App,
2094) -> Vec<(Range<usize>, Arc<str>)> {
2095    let buffer = buffer.read(cx);
2096    editor_edits
2097        .iter()
2098        .map(|(range, text)| {
2099            (
2100                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2101                text.clone(),
2102            )
2103        })
2104        .collect()
2105}
2106
2107#[gpui::test]
2108async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2109    init_test(cx);
2110
2111    let fs = FakeFs::new(cx.executor());
2112    fs.insert_tree(
2113        "/project",
2114        serde_json::json!({
2115            "main.rs": "fn main() {\n    \n}\n"
2116        }),
2117    )
2118    .await;
2119
2120    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2121
2122    let http_client = FakeHttpClient::create(|_req| async move {
2123        Ok(gpui::http_client::Response::builder()
2124            .status(401)
2125            .body("Unauthorized".into())
2126            .unwrap())
2127    });
2128
2129    let client =
2130        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2131    cx.update(|cx| {
2132        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2133    });
2134
2135    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2136
2137    let buffer = project
2138        .update(cx, |project, cx| {
2139            let path = project
2140                .find_project_path(path!("/project/main.rs"), cx)
2141                .unwrap();
2142            project.open_buffer(path, cx)
2143        })
2144        .await
2145        .unwrap();
2146
2147    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2148    ep_store.update(cx, |ep_store, cx| {
2149        ep_store.register_buffer(&buffer, &project, cx)
2150    });
2151    cx.background_executor.run_until_parked();
2152
2153    let completion_task = ep_store.update(cx, |ep_store, cx| {
2154        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2155        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2156    });
2157
2158    let result = completion_task.await;
2159    assert!(
2160        result.is_err(),
2161        "Without authentication and without custom URL, prediction should fail"
2162    );
2163}
2164
2165#[gpui::test]
2166async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2167    init_test(cx);
2168
2169    let fs = FakeFs::new(cx.executor());
2170    fs.insert_tree(
2171        "/project",
2172        serde_json::json!({
2173            "main.rs": "fn main() {\n    \n}\n"
2174        }),
2175    )
2176    .await;
2177
2178    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2179
2180    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2181    let predict_called_clone = predict_called.clone();
2182
2183    let http_client = FakeHttpClient::create({
2184        move |req| {
2185            let uri = req.uri().path().to_string();
2186            let predict_called = predict_called_clone.clone();
2187            async move {
2188                if uri.contains("predict") {
2189                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2190                    Ok(gpui::http_client::Response::builder()
2191                        .body(
2192                            serde_json::to_string(&open_ai::Response {
2193                                id: "test-123".to_string(),
2194                                object: "chat.completion".to_string(),
2195                                created: 0,
2196                                model: "test".to_string(),
2197                                usage: open_ai::Usage {
2198                                    prompt_tokens: 0,
2199                                    completion_tokens: 0,
2200                                    total_tokens: 0,
2201                                },
2202                                choices: vec![open_ai::Choice {
2203                                    index: 0,
2204                                    message: open_ai::RequestMessage::Assistant {
2205                                        content: Some(open_ai::MessageContent::Plain(
2206                                            indoc! {"
2207                                                ```main.rs
2208                                                <|start_of_file|>
2209                                                <|editable_region_start|>
2210                                                fn main() {
2211                                                    println!(\"Hello, world!\");
2212                                                }
2213                                                <|editable_region_end|>
2214                                                ```
2215                                            "}
2216                                            .to_string(),
2217                                        )),
2218                                        tool_calls: vec![],
2219                                    },
2220                                    finish_reason: Some("stop".to_string()),
2221                                }],
2222                            })
2223                            .unwrap()
2224                            .into(),
2225                        )
2226                        .unwrap())
2227                } else {
2228                    Ok(gpui::http_client::Response::builder()
2229                        .status(401)
2230                        .body("Unauthorized".into())
2231                        .unwrap())
2232                }
2233            }
2234        }
2235    });
2236
2237    let client =
2238        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2239    cx.update(|cx| {
2240        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2241    });
2242
2243    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2244
2245    let buffer = project
2246        .update(cx, |project, cx| {
2247            let path = project
2248                .find_project_path(path!("/project/main.rs"), cx)
2249                .unwrap();
2250            project.open_buffer(path, cx)
2251        })
2252        .await
2253        .unwrap();
2254
2255    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2256    ep_store.update(cx, |ep_store, cx| {
2257        ep_store.register_buffer(&buffer, &project, cx)
2258    });
2259    cx.background_executor.run_until_parked();
2260
2261    let completion_task = ep_store.update(cx, |ep_store, cx| {
2262        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2263        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2264        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2265    });
2266
2267    let _ = completion_task.await;
2268
2269    assert!(
2270        predict_called.load(std::sync::atomic::Ordering::SeqCst),
2271        "With custom URL, predict endpoint should be called even without authentication"
2272    );
2273}
2274
2275#[gpui::test]
2276fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2277    let buffer = cx.new(|cx| {
2278        Buffer::local(
2279            indoc! {"
2280                zero
2281                one
2282                two
2283                three
2284                four
2285                five
2286                six
2287                seven
2288                eight
2289                nine
2290                ten
2291                eleven
2292                twelve
2293                thirteen
2294                fourteen
2295                fifteen
2296                sixteen
2297                seventeen
2298                eighteen
2299                nineteen
2300                twenty
2301                twenty-one
2302                twenty-two
2303                twenty-three
2304                twenty-four
2305            "},
2306            cx,
2307        )
2308    });
2309
2310    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2311
2312    buffer.update(cx, |buffer, cx| {
2313        let point = Point::new(12, 0);
2314        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2315        let point = Point::new(8, 0);
2316        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2317    });
2318
2319    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2320
2321    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2322
2323    assert_eq!(
2324        diff,
2325        indoc! {"
2326            @@ -6,10 +6,12 @@
2327             five
2328             six
2329             seven
2330            +FIRST INSERTION
2331             eight
2332             nine
2333             ten
2334             eleven
2335            +SECOND INSERTION
2336             twelve
2337             thirteen
2338             fourteen
2339            "}
2340    );
2341}
2342
// Registered through `ctor`, so it runs when the test binary is loaded rather than being called
// from an individual test. Calls `zlog::init_test()` — presumably this sets up logging for the
// test run; confirm against the `zlog` crate.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}