// edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{
  10        RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage,
  11    },
  12};
  13use futures::{
  14    AsyncReadExt, StreamExt,
  15    channel::{mpsc, oneshot},
  16};
  17use gpui::{
  18    Entity, TestAppContext,
  19    http_client::{FakeHttpClient, Response},
  20};
  21use indoc::indoc;
  22use language::Point;
  23use lsp::LanguageServerId;
  24use parking_lot::Mutex;
  25use pretty_assertions::{assert_eq, assert_matches};
  26use project::{FakeFs, Project};
  27use serde_json::json;
  28use settings::SettingsStore;
  29use std::{path::Path, sync::Arc, time::Duration};
  30use util::{path, rel_path::rel_path};
  31use uuid::Uuid;
  32use zeta_prompt::ZetaPromptInput;
  33
  34use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  35
  36#[gpui::test]
  37async fn test_current_state(cx: &mut TestAppContext) {
  38    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  39    let fs = FakeFs::new(cx.executor());
  40    fs.insert_tree(
  41        "/root",
  42        json!({
  43            "1.txt": "Hello!\nHow\nBye\n",
  44            "2.txt": "Hola!\nComo\nAdios\n"
  45        }),
  46    )
  47    .await;
  48    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  49
  50    let buffer1 = project
  51        .update(cx, |project, cx| {
  52            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  53            project.set_active_path(Some(path.clone()), cx);
  54            project.open_buffer(path, cx)
  55        })
  56        .await
  57        .unwrap();
  58    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  59    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  60
  61    ep_store.update(cx, |ep_store, cx| {
  62        ep_store.register_project(&project, cx);
  63        ep_store.register_buffer(&buffer1, &project, cx);
  64    });
  65
  66    // Prediction for current file
  67
  68    ep_store.update(cx, |ep_store, cx| {
  69        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  70    });
  71    let (request, respond_tx) = requests.predict.next().await.unwrap();
  72
  73    respond_tx
  74        .send(model_response(
  75            request,
  76            indoc! {r"
  77                --- a/root/1.txt
  78                +++ b/root/1.txt
  79                @@ ... @@
  80                 Hello!
  81                -How
  82                +How are you?
  83                 Bye
  84            "},
  85        ))
  86        .unwrap();
  87
  88    cx.run_until_parked();
  89
  90    ep_store.update(cx, |ep_store, cx| {
  91        let prediction = ep_store
  92            .prediction_at(&buffer1, None, &project, cx)
  93            .unwrap();
  94        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  95    });
  96
  97    ep_store.update(cx, |ep_store, _cx| {
  98        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
  99    });
 100
 101    // Prediction for diagnostic in another file
 102
 103    let diagnostic = lsp::Diagnostic {
 104        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 105        severity: Some(lsp::DiagnosticSeverity::ERROR),
 106        message: "Sentence is incomplete".to_string(),
 107        ..Default::default()
 108    };
 109
 110    project.update(cx, |project, cx| {
 111        project.lsp_store().update(cx, |lsp_store, cx| {
 112            lsp_store
 113                .update_diagnostics(
 114                    LanguageServerId(0),
 115                    lsp::PublishDiagnosticsParams {
 116                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 117                        diagnostics: vec![diagnostic],
 118                        version: None,
 119                    },
 120                    None,
 121                    language::DiagnosticSourceKind::Pushed,
 122                    &[],
 123                    cx,
 124                )
 125                .unwrap();
 126        });
 127    });
 128
 129    let (request, respond_tx) = requests.predict.next().await.unwrap();
 130    respond_tx
 131        .send(model_response(
 132            request,
 133            indoc! {r#"
 134                --- a/root/2.txt
 135                +++ b/root/2.txt
 136                @@ ... @@
 137                 Hola!
 138                -Como
 139                +Como estas?
 140                 Adios
 141            "#},
 142        ))
 143        .unwrap();
 144    cx.run_until_parked();
 145
 146    ep_store.update(cx, |ep_store, cx| {
 147        let prediction = ep_store
 148            .prediction_at(&buffer1, None, &project, cx)
 149            .unwrap();
 150        assert_matches!(
 151            prediction,
 152            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 153        );
 154    });
 155
 156    let buffer2 = project
 157        .update(cx, |project, cx| {
 158            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 159            project.open_buffer(path, cx)
 160        })
 161        .await
 162        .unwrap();
 163
 164    ep_store.update(cx, |ep_store, cx| {
 165        let prediction = ep_store
 166            .prediction_at(&buffer2, None, &project, cx)
 167            .unwrap();
 168        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 169    });
 170}
 171
 172#[gpui::test]
 173async fn test_simple_request(cx: &mut TestAppContext) {
 174    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 175    let fs = FakeFs::new(cx.executor());
 176    fs.insert_tree(
 177        "/root",
 178        json!({
 179            "foo.md":  "Hello!\nHow\nBye\n"
 180        }),
 181    )
 182    .await;
 183    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 184
 185    let buffer = project
 186        .update(cx, |project, cx| {
 187            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 188            project.open_buffer(path, cx)
 189        })
 190        .await
 191        .unwrap();
 192    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 193    let position = snapshot.anchor_before(language::Point::new(1, 3));
 194
 195    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 196        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 197    });
 198
 199    let (request, respond_tx) = requests.predict.next().await.unwrap();
 200
 201    // TODO Put back when we have a structured request again
 202    // assert_eq!(
 203    //     request.excerpt_path.as_ref(),
 204    //     Path::new(path!("root/foo.md"))
 205    // );
 206    // assert_eq!(
 207    //     request.cursor_point,
 208    //     Point {
 209    //         line: Line(1),
 210    //         column: 3
 211    //     }
 212    // );
 213
 214    respond_tx
 215        .send(model_response(
 216            request,
 217            indoc! { r"
 218                --- a/root/foo.md
 219                +++ b/root/foo.md
 220                @@ ... @@
 221                 Hello!
 222                -How
 223                +How are you?
 224                 Bye
 225            "},
 226        ))
 227        .unwrap();
 228
 229    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 230
 231    assert_eq!(prediction.edits.len(), 1);
 232    assert_eq!(
 233        prediction.edits[0].0.to_point(&snapshot).start,
 234        language::Point::new(1, 3)
 235    );
 236    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 237}
 238
 239#[gpui::test]
 240async fn test_request_events(cx: &mut TestAppContext) {
 241    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 242    let fs = FakeFs::new(cx.executor());
 243    fs.insert_tree(
 244        "/root",
 245        json!({
 246            "foo.md": "Hello!\n\nBye\n"
 247        }),
 248    )
 249    .await;
 250    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 251
 252    let buffer = project
 253        .update(cx, |project, cx| {
 254            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 255            project.open_buffer(path, cx)
 256        })
 257        .await
 258        .unwrap();
 259
 260    ep_store.update(cx, |ep_store, cx| {
 261        ep_store.register_buffer(&buffer, &project, cx);
 262    });
 263
 264    buffer.update(cx, |buffer, cx| {
 265        buffer.edit(vec![(7..7, "How")], None, cx);
 266    });
 267
 268    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 269    let position = snapshot.anchor_before(language::Point::new(1, 3));
 270
 271    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 272        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 273    });
 274
 275    let (request, respond_tx) = requests.predict.next().await.unwrap();
 276
 277    let prompt = prompt_from_request(&request);
 278    assert!(
 279        prompt.contains(indoc! {"
 280        --- a/root/foo.md
 281        +++ b/root/foo.md
 282        @@ -1,3 +1,3 @@
 283         Hello!
 284        -
 285        +How
 286         Bye
 287    "}),
 288        "{prompt}"
 289    );
 290
 291    respond_tx
 292        .send(model_response(
 293            request,
 294            indoc! {r#"
 295                --- a/root/foo.md
 296                +++ b/root/foo.md
 297                @@ ... @@
 298                 Hello!
 299                -How
 300                +How are you?
 301                 Bye
 302        "#},
 303        ))
 304        .unwrap();
 305
 306    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 307
 308    assert_eq!(prediction.edits.len(), 1);
 309    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 310}
 311
 312#[gpui::test]
 313async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 314    let (ep_store, _requests) = init_test_with_fake_client(cx);
 315    let fs = FakeFs::new(cx.executor());
 316    fs.insert_tree(
 317        "/root",
 318        json!({
 319            "foo.md": "Hello!\n\nBye\n"
 320        }),
 321    )
 322    .await;
 323    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 324
 325    let buffer = project
 326        .update(cx, |project, cx| {
 327            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 328            project.open_buffer(path, cx)
 329        })
 330        .await
 331        .unwrap();
 332
 333    ep_store.update(cx, |ep_store, cx| {
 334        ep_store.register_buffer(&buffer, &project, cx);
 335    });
 336
 337    // First burst: insert "How"
 338    buffer.update(cx, |buffer, cx| {
 339        buffer.edit(vec![(7..7, "How")], None, cx);
 340    });
 341
 342    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 343    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 344    cx.run_until_parked();
 345
 346    // Second burst: append " are you?" immediately after "How" on the same line.
 347    //
 348    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 349    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 350    buffer.update(cx, |buffer, cx| {
 351        buffer.edit(vec![(10..10, " are you?")], None, cx);
 352    });
 353
 354    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 355    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 356    buffer.update(cx, |buffer, cx| {
 357        buffer.edit(vec![(19..19, "!")], None, cx);
 358    });
 359
 360    // Without time-based splitting, there is one event.
 361    let events = ep_store.update(cx, |ep_store, cx| {
 362        ep_store.edit_history_for_project(&project, cx)
 363    });
 364    assert_eq!(events.len(), 1);
 365    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 366    assert_eq!(
 367        diff.as_str(),
 368        indoc! {"
 369            @@ -1,3 +1,3 @@
 370             Hello!
 371            -
 372            +How are you?!
 373             Bye
 374        "}
 375    );
 376
 377    // With time-based splitting, there are two distinct events.
 378    let events = ep_store.update(cx, |ep_store, cx| {
 379        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 380    });
 381    assert_eq!(events.len(), 2);
 382    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 383    assert_eq!(
 384        diff.as_str(),
 385        indoc! {"
 386            @@ -1,3 +1,3 @@
 387             Hello!
 388            -
 389            +How
 390             Bye
 391        "}
 392    );
 393
 394    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 395    assert_eq!(
 396        diff.as_str(),
 397        indoc! {"
 398            @@ -1,3 +1,3 @@
 399             Hello!
 400            -How
 401            +How are you?!
 402             Bye
 403        "}
 404    );
 405}
 406
 407#[gpui::test]
 408async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
 409    let (ep_store, _requests) = init_test_with_fake_client(cx);
 410    let fs = FakeFs::new(cx.executor());
 411
 412    // Create a file with 30 lines to test line-based coalescing
 413    let content = (1..=30)
 414        .map(|i| format!("Line {}\n", i))
 415        .collect::<String>();
 416    fs.insert_tree(
 417        "/root",
 418        json!({
 419            "foo.md": content
 420        }),
 421    )
 422    .await;
 423    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 424
 425    let buffer = project
 426        .update(cx, |project, cx| {
 427            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 428            project.open_buffer(path, cx)
 429        })
 430        .await
 431        .unwrap();
 432
 433    ep_store.update(cx, |ep_store, cx| {
 434        ep_store.register_buffer(&buffer, &project, cx);
 435    });
 436
 437    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
 438    buffer.update(cx, |buffer, cx| {
 439        let start = Point::new(10, 0).to_offset(buffer);
 440        let end = Point::new(13, 0).to_offset(buffer);
 441        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
 442    });
 443
 444    let events = ep_store.update(cx, |ep_store, cx| {
 445        ep_store.edit_history_for_project(&project, cx)
 446    });
 447    assert_eq!(
 448        render_events(&events),
 449        indoc! {"
 450            @@ -8,9 +8,8 @@
 451             Line 8
 452             Line 9
 453             Line 10
 454            -Line 11
 455            -Line 12
 456            -Line 13
 457            +Middle A
 458            +Middle B
 459             Line 14
 460             Line 15
 461             Line 16
 462        "},
 463        "After first edit"
 464    );
 465
 466    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
 467    // This tests that coalescing considers the START of the existing range
 468    buffer.update(cx, |buffer, cx| {
 469        let offset = Point::new(5, 0).to_offset(buffer);
 470        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
 471    });
 472
 473    let events = ep_store.update(cx, |ep_store, cx| {
 474        ep_store.edit_history_for_project(&project, cx)
 475    });
 476    assert_eq!(
 477        render_events(&events),
 478        indoc! {"
 479            @@ -3,14 +3,14 @@
 480             Line 3
 481             Line 4
 482             Line 5
 483            +Above
 484             Line 6
 485             Line 7
 486             Line 8
 487             Line 9
 488             Line 10
 489            -Line 11
 490            -Line 12
 491            -Line 13
 492            +Middle A
 493            +Middle B
 494             Line 14
 495             Line 15
 496             Line 16
 497        "},
 498        "After inserting above (should coalesce)"
 499    );
 500
 501    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
 502    // This tests that coalescing considers the END of the existing range
 503    buffer.update(cx, |buffer, cx| {
 504        let offset = Point::new(14, 0).to_offset(buffer);
 505        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
 506    });
 507
 508    let events = ep_store.update(cx, |ep_store, cx| {
 509        ep_store.edit_history_for_project(&project, cx)
 510    });
 511    assert_eq!(
 512        render_events(&events),
 513        indoc! {"
 514            @@ -3,15 +3,16 @@
 515             Line 3
 516             Line 4
 517             Line 5
 518            +Above
 519             Line 6
 520             Line 7
 521             Line 8
 522             Line 9
 523             Line 10
 524            -Line 11
 525            -Line 12
 526            -Line 13
 527            +Middle A
 528            +Middle B
 529             Line 14
 530            +Below
 531             Line 15
 532             Line 16
 533             Line 17
 534        "},
 535        "After inserting below (should coalesce)"
 536    );
 537
 538    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
 539    // This should NOT coalesce - creates a new event
 540    buffer.update(cx, |buffer, cx| {
 541        let offset = Point::new(25, 0).to_offset(buffer);
 542        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
 543    });
 544
 545    let events = ep_store.update(cx, |ep_store, cx| {
 546        ep_store.edit_history_for_project(&project, cx)
 547    });
 548    assert_eq!(
 549        render_events(&events),
 550        indoc! {"
 551            @@ -3,15 +3,16 @@
 552             Line 3
 553             Line 4
 554             Line 5
 555            +Above
 556             Line 6
 557             Line 7
 558             Line 8
 559             Line 9
 560             Line 10
 561            -Line 11
 562            -Line 12
 563            -Line 13
 564            +Middle A
 565            +Middle B
 566             Line 14
 567            +Below
 568             Line 15
 569             Line 16
 570             Line 17
 571
 572            ---
 573            @@ -23,6 +23,7 @@
 574             Line 22
 575             Line 23
 576             Line 24
 577            +Far below
 578             Line 25
 579             Line 26
 580             Line 27
 581        "},
 582        "After inserting far below (should NOT coalesce)"
 583    );
 584}
 585
 586fn render_events(events: &[StoredEvent]) -> String {
 587    events
 588        .iter()
 589        .map(|e| {
 590            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 591            diff.as_str()
 592        })
 593        .collect::<Vec<_>>()
 594        .join("\n---\n")
 595}
 596
 597#[gpui::test]
 598async fn test_empty_prediction(cx: &mut TestAppContext) {
 599    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 600    let fs = FakeFs::new(cx.executor());
 601    fs.insert_tree(
 602        "/root",
 603        json!({
 604            "foo.md":  "Hello!\nHow\nBye\n"
 605        }),
 606    )
 607    .await;
 608    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 609
 610    let buffer = project
 611        .update(cx, |project, cx| {
 612            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 613            project.open_buffer(path, cx)
 614        })
 615        .await
 616        .unwrap();
 617    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 618    let position = snapshot.anchor_before(language::Point::new(1, 3));
 619
 620    ep_store.update(cx, |ep_store, cx| {
 621        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 622    });
 623
 624    let (request, respond_tx) = requests.predict.next().await.unwrap();
 625    let response = model_response(request, "");
 626    let id = response.id.clone();
 627    respond_tx.send(response).unwrap();
 628
 629    cx.run_until_parked();
 630
 631    ep_store.update(cx, |ep_store, cx| {
 632        assert!(
 633            ep_store
 634                .prediction_at(&buffer, None, &project, cx)
 635                .is_none()
 636        );
 637    });
 638
 639    // prediction is reported as rejected
 640    let (reject_request, _) = requests.reject.next().await.unwrap();
 641
 642    assert_eq!(
 643        &reject_request.rejections,
 644        &[EditPredictionRejection {
 645            request_id: id,
 646            reason: EditPredictionRejectReason::Empty,
 647            was_shown: false
 648        }]
 649    );
 650}
 651
 652#[gpui::test]
 653async fn test_interpolated_empty(cx: &mut TestAppContext) {
 654    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 655    let fs = FakeFs::new(cx.executor());
 656    fs.insert_tree(
 657        "/root",
 658        json!({
 659            "foo.md":  "Hello!\nHow\nBye\n"
 660        }),
 661    )
 662    .await;
 663    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 664
 665    let buffer = project
 666        .update(cx, |project, cx| {
 667            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 668            project.open_buffer(path, cx)
 669        })
 670        .await
 671        .unwrap();
 672    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 673    let position = snapshot.anchor_before(language::Point::new(1, 3));
 674
 675    ep_store.update(cx, |ep_store, cx| {
 676        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 677    });
 678
 679    let (request, respond_tx) = requests.predict.next().await.unwrap();
 680
 681    buffer.update(cx, |buffer, cx| {
 682        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 683    });
 684
 685    let response = model_response(request, SIMPLE_DIFF);
 686    let id = response.id.clone();
 687    respond_tx.send(response).unwrap();
 688
 689    cx.run_until_parked();
 690
 691    ep_store.update(cx, |ep_store, cx| {
 692        assert!(
 693            ep_store
 694                .prediction_at(&buffer, None, &project, cx)
 695                .is_none()
 696        );
 697    });
 698
 699    // prediction is reported as rejected
 700    let (reject_request, _) = requests.reject.next().await.unwrap();
 701
 702    assert_eq!(
 703        &reject_request.rejections,
 704        &[EditPredictionRejection {
 705            request_id: id,
 706            reason: EditPredictionRejectReason::InterpolatedEmpty,
 707            was_shown: false
 708        }]
 709    );
 710}
 711
 712const SIMPLE_DIFF: &str = indoc! { r"
 713    --- a/root/foo.md
 714    +++ b/root/foo.md
 715    @@ ... @@
 716     Hello!
 717    -How
 718    +How are you?
 719     Bye
 720"};
 721
 722#[gpui::test]
 723async fn test_replace_current(cx: &mut TestAppContext) {
 724    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 725    let fs = FakeFs::new(cx.executor());
 726    fs.insert_tree(
 727        "/root",
 728        json!({
 729            "foo.md":  "Hello!\nHow\nBye\n"
 730        }),
 731    )
 732    .await;
 733    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 734
 735    let buffer = project
 736        .update(cx, |project, cx| {
 737            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 738            project.open_buffer(path, cx)
 739        })
 740        .await
 741        .unwrap();
 742    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 743    let position = snapshot.anchor_before(language::Point::new(1, 3));
 744
 745    ep_store.update(cx, |ep_store, cx| {
 746        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 747    });
 748
 749    let (request, respond_tx) = requests.predict.next().await.unwrap();
 750    let first_response = model_response(request, SIMPLE_DIFF);
 751    let first_id = first_response.id.clone();
 752    respond_tx.send(first_response).unwrap();
 753
 754    cx.run_until_parked();
 755
 756    ep_store.update(cx, |ep_store, cx| {
 757        assert_eq!(
 758            ep_store
 759                .prediction_at(&buffer, None, &project, cx)
 760                .unwrap()
 761                .id
 762                .0,
 763            first_id
 764        );
 765    });
 766
 767    // a second request is triggered
 768    ep_store.update(cx, |ep_store, cx| {
 769        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 770    });
 771
 772    let (request, respond_tx) = requests.predict.next().await.unwrap();
 773    let second_response = model_response(request, SIMPLE_DIFF);
 774    let second_id = second_response.id.clone();
 775    respond_tx.send(second_response).unwrap();
 776
 777    cx.run_until_parked();
 778
 779    ep_store.update(cx, |ep_store, cx| {
 780        // second replaces first
 781        assert_eq!(
 782            ep_store
 783                .prediction_at(&buffer, None, &project, cx)
 784                .unwrap()
 785                .id
 786                .0,
 787            second_id
 788        );
 789    });
 790
 791    // first is reported as replaced
 792    let (reject_request, _) = requests.reject.next().await.unwrap();
 793
 794    assert_eq!(
 795        &reject_request.rejections,
 796        &[EditPredictionRejection {
 797            request_id: first_id,
 798            reason: EditPredictionRejectReason::Replaced,
 799            was_shown: false
 800        }]
 801    );
 802}
 803
 804#[gpui::test]
 805async fn test_current_preferred(cx: &mut TestAppContext) {
 806    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 807    let fs = FakeFs::new(cx.executor());
 808    fs.insert_tree(
 809        "/root",
 810        json!({
 811            "foo.md":  "Hello!\nHow\nBye\n"
 812        }),
 813    )
 814    .await;
 815    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 816
 817    let buffer = project
 818        .update(cx, |project, cx| {
 819            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 820            project.open_buffer(path, cx)
 821        })
 822        .await
 823        .unwrap();
 824    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 825    let position = snapshot.anchor_before(language::Point::new(1, 3));
 826
 827    ep_store.update(cx, |ep_store, cx| {
 828        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 829    });
 830
 831    let (request, respond_tx) = requests.predict.next().await.unwrap();
 832    let first_response = model_response(request, SIMPLE_DIFF);
 833    let first_id = first_response.id.clone();
 834    respond_tx.send(first_response).unwrap();
 835
 836    cx.run_until_parked();
 837
 838    ep_store.update(cx, |ep_store, cx| {
 839        assert_eq!(
 840            ep_store
 841                .prediction_at(&buffer, None, &project, cx)
 842                .unwrap()
 843                .id
 844                .0,
 845            first_id
 846        );
 847    });
 848
 849    // a second request is triggered
 850    ep_store.update(cx, |ep_store, cx| {
 851        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 852    });
 853
 854    let (request, respond_tx) = requests.predict.next().await.unwrap();
 855    // worse than current prediction
 856    let second_response = model_response(
 857        request,
 858        indoc! { r"
 859            --- a/root/foo.md
 860            +++ b/root/foo.md
 861            @@ ... @@
 862             Hello!
 863            -How
 864            +How are
 865             Bye
 866        "},
 867    );
 868    let second_id = second_response.id.clone();
 869    respond_tx.send(second_response).unwrap();
 870
 871    cx.run_until_parked();
 872
 873    ep_store.update(cx, |ep_store, cx| {
 874        // first is preferred over second
 875        assert_eq!(
 876            ep_store
 877                .prediction_at(&buffer, None, &project, cx)
 878                .unwrap()
 879                .id
 880                .0,
 881            first_id
 882        );
 883    });
 884
 885    // second is reported as rejected
 886    let (reject_request, _) = requests.reject.next().await.unwrap();
 887
 888    assert_eq!(
 889        &reject_request.rejections,
 890        &[EditPredictionRejection {
 891            request_id: second_id,
 892            reason: EditPredictionRejectReason::CurrentPreferred,
 893            was_shown: false
 894        }]
 895    );
 896}
 897
 898#[gpui::test]
 899async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 900    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 901    let fs = FakeFs::new(cx.executor());
 902    fs.insert_tree(
 903        "/root",
 904        json!({
 905            "foo.md":  "Hello!\nHow\nBye\n"
 906        }),
 907    )
 908    .await;
 909    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 910
 911    let buffer = project
 912        .update(cx, |project, cx| {
 913            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 914            project.open_buffer(path, cx)
 915        })
 916        .await
 917        .unwrap();
 918    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 919    let position = snapshot.anchor_before(language::Point::new(1, 3));
 920
 921    // start two refresh tasks
 922    ep_store.update(cx, |ep_store, cx| {
 923        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 924    });
 925
 926    let (request1, respond_first) = requests.predict.next().await.unwrap();
 927
 928    ep_store.update(cx, |ep_store, cx| {
 929        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 930    });
 931
 932    let (request, respond_second) = requests.predict.next().await.unwrap();
 933
 934    // wait for throttle
 935    cx.run_until_parked();
 936
 937    // second responds first
 938    let second_response = model_response(request, SIMPLE_DIFF);
 939    let second_id = second_response.id.clone();
 940    respond_second.send(second_response).unwrap();
 941
 942    cx.run_until_parked();
 943
 944    ep_store.update(cx, |ep_store, cx| {
 945        // current prediction is second
 946        assert_eq!(
 947            ep_store
 948                .prediction_at(&buffer, None, &project, cx)
 949                .unwrap()
 950                .id
 951                .0,
 952            second_id
 953        );
 954    });
 955
 956    let first_response = model_response(request1, SIMPLE_DIFF);
 957    let first_id = first_response.id.clone();
 958    respond_first.send(first_response).unwrap();
 959
 960    cx.run_until_parked();
 961
 962    ep_store.update(cx, |ep_store, cx| {
 963        // current prediction is still second, since first was cancelled
 964        assert_eq!(
 965            ep_store
 966                .prediction_at(&buffer, None, &project, cx)
 967                .unwrap()
 968                .id
 969                .0,
 970            second_id
 971        );
 972    });
 973
 974    // first is reported as rejected
 975    let (reject_request, _) = requests.reject.next().await.unwrap();
 976
 977    cx.run_until_parked();
 978
 979    assert_eq!(
 980        &reject_request.rejections,
 981        &[EditPredictionRejection {
 982            request_id: first_id,
 983            reason: EditPredictionRejectReason::Canceled,
 984            was_shown: false
 985        }]
 986    );
 987}
 988
/// Verifies what happens when a third prediction refresh starts while two earlier
/// requests are still pending.
///
/// - The second request (index 1) is recorded in `cancelled_predictions`.
/// - Its late response never becomes the current prediction.
/// - The third response replaces the first.
/// - The rejection report lists the second prediction as `Canceled` and the first as
///   `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(request1, SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request still gets a response; it must be ignored.
    let cancelled_response = model_response(request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(request3, SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected (and first as replaced), in a single batch
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1124
1125#[gpui::test]
1126async fn test_rejections_flushing(cx: &mut TestAppContext) {
1127    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1128
1129    ep_store.update(cx, |ep_store, _cx| {
1130        ep_store.reject_prediction(
1131            EditPredictionId("test-1".into()),
1132            EditPredictionRejectReason::Discarded,
1133            false,
1134        );
1135        ep_store.reject_prediction(
1136            EditPredictionId("test-2".into()),
1137            EditPredictionRejectReason::Canceled,
1138            true,
1139        );
1140    });
1141
1142    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1143    cx.run_until_parked();
1144
1145    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1146    respond_tx.send(()).unwrap();
1147
1148    // batched
1149    assert_eq!(reject_request.rejections.len(), 2);
1150    assert_eq!(
1151        reject_request.rejections[0],
1152        EditPredictionRejection {
1153            request_id: "test-1".to_string(),
1154            reason: EditPredictionRejectReason::Discarded,
1155            was_shown: false
1156        }
1157    );
1158    assert_eq!(
1159        reject_request.rejections[1],
1160        EditPredictionRejection {
1161            request_id: "test-2".to_string(),
1162            reason: EditPredictionRejectReason::Canceled,
1163            was_shown: true
1164        }
1165    );
1166
1167    // Reaching batch size limit sends without debounce
1168    ep_store.update(cx, |ep_store, _cx| {
1169        for i in 0..70 {
1170            ep_store.reject_prediction(
1171                EditPredictionId(format!("batch-{}", i).into()),
1172                EditPredictionRejectReason::Discarded,
1173                false,
1174            );
1175        }
1176    });
1177
1178    // First MAX/2 items are sent immediately
1179    cx.run_until_parked();
1180    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1181    respond_tx.send(()).unwrap();
1182
1183    assert_eq!(reject_request.rejections.len(), 50);
1184    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1185    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1186
1187    // Remaining items are debounced with the next batch
1188    cx.executor().advance_clock(Duration::from_secs(15));
1189    cx.run_until_parked();
1190
1191    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1192    respond_tx.send(()).unwrap();
1193
1194    assert_eq!(reject_request.rejections.len(), 20);
1195    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1196    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1197
1198    // Request failure
1199    ep_store.update(cx, |ep_store, _cx| {
1200        ep_store.reject_prediction(
1201            EditPredictionId("retry-1".into()),
1202            EditPredictionRejectReason::Discarded,
1203            false,
1204        );
1205    });
1206
1207    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1208    cx.run_until_parked();
1209
1210    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1211    assert_eq!(reject_request.rejections.len(), 1);
1212    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1213    // Simulate failure
1214    drop(_respond_tx);
1215
1216    // Add another rejection
1217    ep_store.update(cx, |ep_store, _cx| {
1218        ep_store.reject_prediction(
1219            EditPredictionId("retry-2".into()),
1220            EditPredictionRejectReason::Discarded,
1221            false,
1222        );
1223    });
1224
1225    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1226    cx.run_until_parked();
1227
1228    // Retry should include both the failed item and the new one
1229    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1230    respond_tx.send(()).unwrap();
1231
1232    assert_eq!(reject_request.rejections.len(), 2);
1233    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1234    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1235}
1236
1237// Skipped until we start including diagnostics in prompt
1238// #[gpui::test]
1239// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1240//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1241//     let fs = FakeFs::new(cx.executor());
1242//     fs.insert_tree(
1243//         "/root",
1244//         json!({
1245//             "foo.md": "Hello!\nBye"
1246//         }),
1247//     )
1248//     .await;
1249//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1250
1251//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1252//     let diagnostic = lsp::Diagnostic {
1253//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1254//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1255//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1256//         ..Default::default()
1257//     };
1258
1259//     project.update(cx, |project, cx| {
1260//         project.lsp_store().update(cx, |lsp_store, cx| {
1261//             // Create some diagnostics
1262//             lsp_store
1263//                 .update_diagnostics(
1264//                     LanguageServerId(0),
1265//                     lsp::PublishDiagnosticsParams {
1266//                         uri: path_to_buffer_uri.clone(),
1267//                         diagnostics: vec![diagnostic],
1268//                         version: None,
1269//                     },
1270//                     None,
1271//                     language::DiagnosticSourceKind::Pushed,
1272//                     &[],
1273//                     cx,
1274//                 )
1275//                 .unwrap();
1276//         });
1277//     });
1278
1279//     let buffer = project
1280//         .update(cx, |project, cx| {
1281//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1282//             project.open_buffer(path, cx)
1283//         })
1284//         .await
1285//         .unwrap();
1286
1287//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1288//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1289
1290//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1291//         ep_store.request_prediction(&project, &buffer, position, cx)
1292//     });
1293
1294//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1295
1296//     assert_eq!(request.diagnostic_groups.len(), 1);
1297//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1298//         .unwrap();
1299//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1300//     assert_eq!(
1301//         value,
1302//         json!({
1303//             "entries": [{
1304//                 "range": {
1305//                     "start": 8,
1306//                     "end": 10
1307//                 },
1308//                 "diagnostic": {
1309//                     "source": null,
1310//                     "code": null,
1311//                     "code_description": null,
1312//                     "severity": 1,
1313//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1314//                     "markdown": null,
1315//                     "group_id": 0,
1316//                     "is_primary": true,
1317//                     "is_disk_based": false,
1318//                     "is_unnecessary": false,
1319//                     "source_kind": "Pushed",
1320//                     "data": null,
1321//                     "underline": true
1322//                 }
1323//             }],
1324//             "primary_ix": 0
1325//         })
1326//     );
1327// }
1328
1329// Generate a model response that would apply the given diff to the active file.
1330fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse {
1331    let prompt = &request.prompt;
1332
1333    let current_marker = "<|fim_middle|>current\n";
1334    let updated_marker = "<|fim_middle|>updated\n";
1335    let cursor = "<|user_cursor|>";
1336
1337    let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
1338    let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
1339    let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
1340    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1341
1342    RawCompletionResponse {
1343        id: Uuid::new_v4().to_string(),
1344        object: "text_completion".into(),
1345        created: 0,
1346        model: "model".into(),
1347        choices: vec![RawCompletionChoice {
1348            text: new_excerpt,
1349            finish_reason: None,
1350        }],
1351        usage: RawCompletionUsage {
1352            prompt_tokens: 0,
1353            completion_tokens: 0,
1354            total_tokens: 0,
1355        },
1356    }
1357}
1358
/// Returns the prompt text carried by `request`, borrowed from the request.
fn prompt_from_request(request: &RawCompletionRequest) -> &str {
    &request.prompt
}
1362
/// Receivers for the HTTP requests intercepted by the fake client built in
/// `init_test_with_fake_client`.
///
/// Each item pairs the deserialized request body with a oneshot sender. The test uses
/// the sender to supply the response. Dropping the sender without sending makes the
/// intercepted request fail.
struct RequestChannels {
    /// Requests sent to `/predict_edits/raw`.
    predict:
        mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender<RawCompletionResponse>)>,
    /// Requests sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1368
/// Creates the global `EditPredictionStore` backed by a fake HTTP client.
///
/// It first installs a test `SettingsStore` global and test logging. HTTP requests are
/// then routed by path:
/// - `/client/llm_tokens` responds with a fixed `"test"` token.
/// - `/predict_edits/raw` and `/predict_edits/reject` forward the deserialized body,
///   plus a oneshot response sender, to the returned [`RequestChannels`]. The HTTP call
///   completes only once the test sends a response. If the sender is dropped instead,
///   the call returns an error.
/// - Any other path panics.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                // Each request gets its own sender clones so the returned future can own them.
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Hand the request to the test and wait for it to respond;
                            // a dropped sender turns into an error via `?`.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Same hand-off as above, for rejection reports.
                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1437
/// Text of the 0BSD license. The data-collection tests below write it as a `LICENSE`
/// file, and worktrees containing it are expected to permit data collection there.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1439
/// Checks how a prediction built against `"Lorem ipsum dolor"` is re-based onto later
/// buffer states by `EditPrediction::interpolate`.
///
/// The prediction replaces `2..5` with `"REM"` and deletes `9..11`. The expected results
/// for each user edit are:
/// - Deleting the predicted range shifts the remaining edits.
/// - Typing a prefix of the predicted text leaves only the untyped suffix.
/// - Typing all of it leaves only the deletion.
/// - Performing the deletion leaves only the insertion.
/// - A diverging edit makes interpolation return `None`.
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    let prediction = EditPrediction {
        edits,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: the edits come back as predicted.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // The user deletes the replaced range: insertion becomes empty-ranged, deletion shifts.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original prediction.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing "R" (a prefix of "REM") leaves only "EM" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E" leaves only "M".
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing "M" completes the insertion; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Removing the typed "M" brings the "M" insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // The user performs the predicted deletion; only the insertion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // A diverging edit invalidates the prediction.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1553
/// Checks `apply_edit_prediction` on completions wrapped in
/// `<|editable_region_start|>` / `<|editable_region_end|>` markers.
///
/// In both cases the expected buffer text is the completion's code without the region
/// markers and without the blank line that precedes the end marker.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Renaming identifiers within a line.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
    );

    // Extending a string literal and adding the missing semicolon.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
    );
}
1611
1612#[gpui::test]
1613async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1614    init_test(cx);
1615
1616    let buffer_content = "lorem\n";
1617    let completion_response = indoc! {"
1618            ```animals.js
1619            <|start_of_file|>
1620            <|editable_region_start|>
1621            lorem
1622            ipsum
1623            <|editable_region_end|>
1624            ```"};
1625
1626    assert_eq!(
1627        apply_edit_prediction(buffer_content, completion_response, cx).await,
1628        "lorem\nipsum"
1629    );
1630}
1631
1632#[gpui::test]
1633async fn test_can_collect_data(cx: &mut TestAppContext) {
1634    init_test(cx);
1635
1636    let fs = project::FakeFs::new(cx.executor());
1637    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1638        .await;
1639
1640    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1641    let buffer = project
1642        .update(cx, |project, cx| {
1643            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1644        })
1645        .await
1646        .unwrap();
1647
1648    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1649    ep_store.update(cx, |ep_store, _cx| {
1650        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1651    });
1652
1653    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1654    assert_eq!(
1655        captured_request.lock().clone().unwrap().can_collect_data,
1656        true
1657    );
1658
1659    ep_store.update(cx, |ep_store, _cx| {
1660        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1661    });
1662
1663    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1664    assert_eq!(
1665        captured_request.lock().clone().unwrap().can_collect_data,
1666        false
1667    );
1668}
1669
/// A buffer created with `Buffer::remote` is never marked collectable, even when data
/// collection is enabled.
#[gpui::test]
async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    // The project has no worktrees; the buffer below does not belong to any of them.
    let project = Project::test(fs.clone(), [], cx).await;

    let buffer = cx.new(|_cx| {
        Buffer::remote(
            language::BufferId::new(1).unwrap(),
            ReplicaId::new(1),
            language::Capability::ReadWrite,
            "fn main() {\n    println!(\"Hello\");\n}",
        )
    });

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1697
1698#[gpui::test]
1699async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1700    init_test(cx);
1701
1702    let fs = project::FakeFs::new(cx.executor());
1703    fs.insert_tree(
1704        path!("/project"),
1705        json!({
1706            "LICENSE": BSD_0_TXT,
1707            ".env": "SECRET_KEY=secret"
1708        }),
1709    )
1710    .await;
1711
1712    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1713    let buffer = project
1714        .update(cx, |project, cx| {
1715            project.open_local_buffer("/project/.env", cx)
1716        })
1717        .await
1718        .unwrap();
1719
1720    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1721    ep_store.update(cx, |ep_store, _cx| {
1722        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1723    });
1724
1725    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1726    assert_eq!(
1727        captured_request.lock().clone().unwrap().can_collect_data,
1728        false
1729    );
1730}
1731
/// A local buffer with no backing file (untitled) is not collectable, even when data
/// collection is enabled.
#[gpui::test]
async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [], cx).await;
    let buffer = cx.new(|cx| Buffer::local("", cx));

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1751
1752#[gpui::test]
1753async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1754    init_test(cx);
1755
1756    let fs = project::FakeFs::new(cx.executor());
1757    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1758        .await;
1759
1760    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1761    let buffer = project
1762        .update(cx, |project, cx| {
1763            project.open_local_buffer("/project/main.rs", cx)
1764        })
1765        .await
1766        .unwrap();
1767
1768    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1769    ep_store.update(cx, |ep_store, _cx| {
1770        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1771    });
1772
1773    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1774    assert_eq!(
1775        captured_request.lock().clone().unwrap().can_collect_data,
1776        false
1777    );
1778}
1779
/// `can_collect_data` is recomputed from the buffer's current file.
///
/// The buffer starts in a worktree with a `LICENSE` and is collectable. After its file is
/// swapped (via `file_updated`) for one in a worktree without a `LICENSE`, the next
/// request reports `false`.
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Load the `File` for main.rs in the closed-source worktree, to attach it to `buffer`.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    // Simulate the buffer moving to the closed-source worktree.
    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1843
1844#[gpui::test]
1845async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1846    init_test(cx);
1847
1848    let fs = project::FakeFs::new(cx.executor());
1849    fs.insert_tree(
1850        path!("/worktree1"),
1851        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1852    )
1853    .await;
1854    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1855        .await;
1856
1857    let project = Project::test(
1858        fs.clone(),
1859        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1860        cx,
1861    )
1862    .await;
1863    let buffer = project
1864        .update(cx, |project, cx| {
1865            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1866        })
1867        .await
1868        .unwrap();
1869    let private_buffer = project
1870        .update(cx, |project, cx| {
1871            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1872        })
1873        .await
1874        .unwrap();
1875
1876    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1877    ep_store.update(cx, |ep_store, _cx| {
1878        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1879    });
1880
1881    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1882    assert_eq!(
1883        captured_request.lock().clone().unwrap().can_collect_data,
1884        true
1885    );
1886
1887    // this has a side effect of registering the buffer to watch for edits
1888    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1889    assert_eq!(
1890        captured_request.lock().clone().unwrap().can_collect_data,
1891        false
1892    );
1893
1894    private_buffer.update(cx, |private_buffer, cx| {
1895        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1896    });
1897
1898    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1899    assert_eq!(
1900        captured_request.lock().clone().unwrap().can_collect_data,
1901        false
1902    );
1903
1904    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1905    // included
1906    buffer.update(cx, |buffer, cx| {
1907        buffer.edit(
1908            [(
1909                0..0,
1910                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1911            )],
1912            None,
1913            cx,
1914        );
1915    });
1916
1917    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1918    assert_eq!(
1919        captured_request.lock().clone().unwrap().can_collect_data,
1920        true
1921    );
1922}
1923
1924fn init_test(cx: &mut TestAppContext) {
1925    cx.update(|cx| {
1926        let settings_store = SettingsStore::test(cx);
1927        cx.set_global(settings_store);
1928    });
1929}
1930
1931async fn apply_edit_prediction(
1932    buffer_content: &str,
1933    completion_response: &str,
1934    cx: &mut TestAppContext,
1935) -> String {
1936    let fs = project::FakeFs::new(cx.executor());
1937    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1938    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1939    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1940    *response.lock() = completion_response.to_string();
1941    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1942    buffer.update(cx, |buffer, cx| {
1943        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1944    });
1945    buffer.read_with(cx, |buffer, _| buffer.text())
1946}
1947
1948async fn run_edit_prediction(
1949    buffer: &Entity<Buffer>,
1950    project: &Entity<Project>,
1951    ep_store: &Entity<EditPredictionStore>,
1952    cx: &mut TestAppContext,
1953) -> EditPrediction {
1954    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
1955    ep_store.update(cx, |ep_store, cx| {
1956        ep_store.register_buffer(buffer, &project, cx)
1957    });
1958    cx.background_executor.run_until_parked();
1959    let prediction_task = ep_store.update(cx, |ep_store, cx| {
1960        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
1961    });
1962    prediction_task.await.unwrap().unwrap().prediction.unwrap()
1963}
1964
1965async fn make_test_ep_store(
1966    project: &Entity<Project>,
1967    cx: &mut TestAppContext,
1968) -> (
1969    Entity<EditPredictionStore>,
1970    Arc<Mutex<Option<PredictEditsBody>>>,
1971    Arc<Mutex<String>>,
1972) {
1973    let default_response = indoc! {"
1974            ```main.rs
1975            <|start_of_file|>
1976            <|editable_region_start|>
1977            hello world
1978            <|editable_region_end|>
1979            ```"
1980    };
1981    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
1982    let completion_response: Arc<Mutex<String>> =
1983        Arc::new(Mutex::new(default_response.to_string()));
1984    let http_client = FakeHttpClient::create({
1985        let captured_request = captured_request.clone();
1986        let completion_response = completion_response.clone();
1987        let mut next_request_id = 0;
1988        move |req| {
1989            let captured_request = captured_request.clone();
1990            let completion_response = completion_response.clone();
1991            async move {
1992                match (req.method(), req.uri().path()) {
1993                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
1994                        .status(200)
1995                        .body(
1996                            serde_json::to_string(&CreateLlmTokenResponse {
1997                                token: LlmToken("the-llm-token".to_string()),
1998                            })
1999                            .unwrap()
2000                            .into(),
2001                        )
2002                        .unwrap()),
2003                    (&Method::POST, "/predict_edits/v2") => {
2004                        let mut request_body = String::new();
2005                        req.into_body().read_to_string(&mut request_body).await?;
2006                        *captured_request.lock() =
2007                            Some(serde_json::from_str(&request_body).unwrap());
2008                        next_request_id += 1;
2009                        Ok(http_client::Response::builder()
2010                            .status(200)
2011                            .body(
2012                                serde_json::to_string(&PredictEditsResponse {
2013                                    request_id: format!("request-{next_request_id}"),
2014                                    output_excerpt: completion_response.lock().clone(),
2015                                })
2016                                .unwrap()
2017                                .into(),
2018                            )
2019                            .unwrap())
2020                    }
2021                    _ => Ok(http_client::Response::builder()
2022                        .status(404)
2023                        .body("Not Found".into())
2024                        .unwrap()),
2025                }
2026            }
2027        }
2028    });
2029
2030    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2031    cx.update(|cx| {
2032        RefreshLlmTokenListener::register(client.clone(), cx);
2033    });
2034    let _server = FakeServer::for_client(42, &client, cx).await;
2035
2036    let ep_store = cx.new(|cx| {
2037        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2038        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2039
2040        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2041        for worktree in worktrees {
2042            let worktree_id = worktree.read(cx).id();
2043            ep_store
2044                .get_or_init_project(project, cx)
2045                .license_detection_watchers
2046                .entry(worktree_id)
2047                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2048        }
2049
2050        ep_store
2051    });
2052
2053    (ep_store, captured_request, completion_response)
2054}
2055
2056fn to_completion_edits(
2057    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2058    buffer: &Entity<Buffer>,
2059    cx: &App,
2060) -> Vec<(Range<Anchor>, Arc<str>)> {
2061    let buffer = buffer.read(cx);
2062    iterator
2063        .into_iter()
2064        .map(|(range, text)| {
2065            (
2066                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2067                text,
2068            )
2069        })
2070        .collect()
2071}
2072
2073fn from_completion_edits(
2074    editor_edits: &[(Range<Anchor>, Arc<str>)],
2075    buffer: &Entity<Buffer>,
2076    cx: &App,
2077) -> Vec<(Range<usize>, Arc<str>)> {
2078    let buffer = buffer.read(cx);
2079    editor_edits
2080        .iter()
2081        .map(|(range, text)| {
2082            (
2083                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2084                text.clone(),
2085            )
2086        })
2087        .collect()
2088}
2089
2090#[gpui::test]
2091async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2092    init_test(cx);
2093
2094    let fs = FakeFs::new(cx.executor());
2095    fs.insert_tree(
2096        "/project",
2097        serde_json::json!({
2098            "main.rs": "fn main() {\n    \n}\n"
2099        }),
2100    )
2101    .await;
2102
2103    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2104
2105    let http_client = FakeHttpClient::create(|_req| async move {
2106        Ok(gpui::http_client::Response::builder()
2107            .status(401)
2108            .body("Unauthorized".into())
2109            .unwrap())
2110    });
2111
2112    let client =
2113        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2114    cx.update(|cx| {
2115        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2116    });
2117
2118    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2119
2120    let buffer = project
2121        .update(cx, |project, cx| {
2122            let path = project
2123                .find_project_path(path!("/project/main.rs"), cx)
2124                .unwrap();
2125            project.open_buffer(path, cx)
2126        })
2127        .await
2128        .unwrap();
2129
2130    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2131    ep_store.update(cx, |ep_store, cx| {
2132        ep_store.register_buffer(&buffer, &project, cx)
2133    });
2134    cx.background_executor.run_until_parked();
2135
2136    let completion_task = ep_store.update(cx, |ep_store, cx| {
2137        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2138        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2139    });
2140
2141    let result = completion_task.await;
2142    assert!(
2143        result.is_err(),
2144        "Without authentication and without custom URL, prediction should fail"
2145    );
2146}
2147
// With a custom predict-edits URL configured, the predict endpoint must be contacted even
// though the client is not authenticated. Every request whose path does not contain "predict"
// is answered with `401 Unauthorized`.
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Set by the fake server when a request whose path contains "predict" arrives.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                if uri.contains("predict") {
                    // Respond with an OpenAI-style chat completion whose content is a
                    // Zeta1-formatted excerpt for `main.rs`.
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                                ```main.rs
                                                <|start_of_file|>
                                                <|editable_region_start|>
                                                fn main() {
                                                    println!(\"Hello, world!\");
                                                }
                                                <|editable_region_end|>
                                                ```
                                            "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    // Everything else (e.g. token requests) is treated as unauthenticated.
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Row 1 is the indented blank line inside `main`; column 4 is its end.
    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    let completion_task = ep_store.update(cx, |ep_store, cx| {
        // The custom URL's path ("/predict") is one the fake server answers successfully.
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // The result is deliberately ignored: this test only checks that the endpoint was reached.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2257
2258#[gpui::test]
2259fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2260    let buffer = cx.new(|cx| {
2261        Buffer::local(
2262            indoc! {"
2263                zero
2264                one
2265                two
2266                three
2267                four
2268                five
2269                six
2270                seven
2271                eight
2272                nine
2273                ten
2274                eleven
2275                twelve
2276                thirteen
2277                fourteen
2278                fifteen
2279                sixteen
2280                seventeen
2281                eighteen
2282                nineteen
2283                twenty
2284                twenty-one
2285                twenty-two
2286                twenty-three
2287                twenty-four
2288            "},
2289            cx,
2290        )
2291    });
2292
2293    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2294
2295    buffer.update(cx, |buffer, cx| {
2296        let point = Point::new(12, 0);
2297        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2298        let point = Point::new(8, 0);
2299        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2300    });
2301
2302    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2303
2304    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2305
2306    assert_eq!(
2307        diff,
2308        indoc! {"
2309            @@ -6,10 +6,12 @@
2310             five
2311             six
2312             seven
2313            +FIRST INSERTION
2314             eight
2315             nine
2316             ten
2317             eleven
2318            +SECOND INSERTION
2319             twelve
2320             thirteen
2321             fourteen
2322            "}
2323    );
2324}
2325
// Runs once when the test binary is loaded (via `ctor`). It calls `zlog::init_test()`, so
// logging setup does not need to be repeated in each test.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}