// edit_prediction_tests.rs — tests for the edit-prediction store.

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  10};
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::App;
  16use gpui::{
  17    Entity, TestAppContext,
  18    http_client::{FakeHttpClient, Response},
  19};
  20use indoc::indoc;
  21use language::{Buffer, Point};
  22use lsp::LanguageServerId;
  23use parking_lot::Mutex;
  24use pretty_assertions::{assert_eq, assert_matches};
  25use project::{FakeFs, Project};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, sync::Arc, time::Duration};
  29use util::{path, rel_path::rel_path};
  30use uuid::Uuid;
  31use zeta_prompt::ZetaPromptInput;
  32
  33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  34
// Exercises the store's "current prediction" state machine across two files:
// a prediction that edits the queried buffer is reported as `Local`, a
// prediction that edits a different file is reported as `Jump`, and once the
// other file's buffer is opened the same prediction is `Local` from there.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            // Make 1.txt the active file before any prediction is requested.
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor just after "How" on line 1 of 1.txt.
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Model proposes an edit inside the queried buffer (1.txt).
    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // An edit targeting the buffer being queried surfaces as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Discard the current prediction so the next one starts from a clean slate.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing a diagnostic for 2.txt triggers a new prediction request with
    // no explicit refresh call (NOTE(review): inferred from the request arriving
    // below — confirm against the store's diagnostics handling).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Queried from buffer1, the prediction lives in another file, so it is a
    // `Jump` pointing at 2.txt.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Once 2.txt's buffer is open, the same prediction is `Local` from there.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 170
 171#[gpui::test]
 172async fn test_simple_request(cx: &mut TestAppContext) {
 173    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 174    let fs = FakeFs::new(cx.executor());
 175    fs.insert_tree(
 176        "/root",
 177        json!({
 178            "foo.md":  "Hello!\nHow\nBye\n"
 179        }),
 180    )
 181    .await;
 182    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 183
 184    let buffer = project
 185        .update(cx, |project, cx| {
 186            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 187            project.open_buffer(path, cx)
 188        })
 189        .await
 190        .unwrap();
 191    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 192    let position = snapshot.anchor_before(language::Point::new(1, 3));
 193
 194    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 195        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 196    });
 197
 198    let (request, respond_tx) = requests.predict.next().await.unwrap();
 199
 200    // TODO Put back when we have a structured request again
 201    // assert_eq!(
 202    //     request.excerpt_path.as_ref(),
 203    //     Path::new(path!("root/foo.md"))
 204    // );
 205    // assert_eq!(
 206    //     request.cursor_point,
 207    //     Point {
 208    //         line: Line(1),
 209    //         column: 3
 210    //     }
 211    // );
 212
 213    respond_tx
 214        .send(model_response(
 215            &request,
 216            indoc! { r"
 217                --- a/root/foo.md
 218                +++ b/root/foo.md
 219                @@ ... @@
 220                 Hello!
 221                -How
 222                +How are you?
 223                 Bye
 224            "},
 225        ))
 226        .unwrap();
 227
 228    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 229
 230    assert_eq!(prediction.edits.len(), 1);
 231    assert_eq!(
 232        prediction.edits[0].0.to_point(&snapshot).start,
 233        language::Point::new(1, 3)
 234    );
 235    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 236}
 237
// Verifies that edits made to a registered buffer are recorded as edit-history
// events and that their unified diff is embedded in the prompt sent with the
// next prediction request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Register the buffer so its edits are tracked in the edit history.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Fill the empty second line with "How" (offset 7 is right after "Hello!\n").
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The outgoing prompt must contain the diff of the edit recorded above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 310
// Covers both edit-history getters: the plain getter coalesces edits separated
// by a pause into a single event, while the pause-splitting getter breaks the
// last event at the recorded pause boundary into two separate diffs.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How" (offset 7 = the empty second line).
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First half: everything before the pause.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second half: both post-pause edits, relative to the pre-pause state.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 405
// Exercises line-span coalescing of edit-history events: edits landing near an
// existing event's line range (within 8 lines, per the inline notes below —
// NOTE(review): confirm against the store's grouping constant) merge into that
// event, while edits farther away start a new event.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    // Two events now: the coalesced near-range edits, then the distant insert.
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 584
 585fn render_events(events: &[StoredEvent]) -> String {
 586    events
 587        .iter()
 588        .map(|e| {
 589            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 590            diff.as_str()
 591        })
 592        .collect::<Vec<_>>()
 593        .join("\n---\n")
 594}
 595
 596#[gpui::test]
 597async fn test_empty_prediction(cx: &mut TestAppContext) {
 598    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 599    let fs = FakeFs::new(cx.executor());
 600    fs.insert_tree(
 601        "/root",
 602        json!({
 603            "foo.md":  "Hello!\nHow\nBye\n"
 604        }),
 605    )
 606    .await;
 607    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 608
 609    let buffer = project
 610        .update(cx, |project, cx| {
 611            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 612            project.open_buffer(path, cx)
 613        })
 614        .await
 615        .unwrap();
 616    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 617    let position = snapshot.anchor_before(language::Point::new(1, 3));
 618
 619    ep_store.update(cx, |ep_store, cx| {
 620        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 621    });
 622
 623    let (request, respond_tx) = requests.predict.next().await.unwrap();
 624    let response = model_response(&request, "");
 625    let id = response.request_id.clone();
 626    respond_tx.send(response).unwrap();
 627
 628    cx.run_until_parked();
 629
 630    ep_store.update(cx, |ep_store, cx| {
 631        assert!(
 632            ep_store
 633                .prediction_at(&buffer, None, &project, cx)
 634                .is_none()
 635        );
 636    });
 637
 638    // prediction is reported as rejected
 639    let (reject_request, _) = requests.reject.next().await.unwrap();
 640
 641    assert_eq!(
 642        &reject_request.rejections,
 643        &[EditPredictionRejection {
 644            request_id: id,
 645            reason: EditPredictionRejectReason::Empty,
 646            was_shown: false
 647        }]
 648    );
 649}
 650
// If the user's own typing applies the predicted change while the request is
// still in flight, interpolating the response against the new buffer contents
// leaves no remaining edits: the prediction is dropped and reported as
// rejected with reason `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Manually apply the exact change SIMPLE_DIFF would make, before the
    // response arrives.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // Interpolation produced no edits, so no prediction is surfaced.
    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
 710
/// Canonical model response shared by several tests below: a unified diff that
/// rewrites "How" to "How are you?" in root/foo.md.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 720
// A second prediction arriving after the first replaces it as the current
// prediction, and the first is reported back as rejected with reason
// `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // The second response proposes the same edit, so it is at least as good as
    // the current prediction.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
 802
// When a newer prediction is worse than the current one (here: only a prefix
// of the current suggestion), the current prediction is kept and the newer one
// is reported as rejected with reason `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
 896
// With two requests in flight, a response to the later request cancels the
// earlier pending one: the earlier response is ignored when it finally arrives,
// and that request is reported as rejected with reason `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the earlier (already-canceled) request respond late.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
 987
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    // Scenario: three refreshes are requested for the same cursor position.
    // Only two predictions may be pending at once, so starting the third
    // cancels the second (the older non-current pending one). The first and
    // third responses are applied in turn; the second's response arrives but
    // is discarded, and the cancellation/replacement are reported to the
    // reject endpoint.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor in the middle of line 1 ("How").
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        // (the value 1 is the second request's index among pending predictions)
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Respond to the first (oldest, still valid) request.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled (second) request; its result must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    // (both the cancellation of the 2nd and the replacement of the 1st are
    // batched into a single reject request)
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1123
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    // Covers batching behavior of rejection reporting:
    // 1. rejections within the debounce window are flushed as one batch,
    // 2. reaching the batch-size limit flushes immediately (without debounce),
    // 3. items from a failed flush are retained and retried with the next batch.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            cx,
        );
    });

    // Let the debounce timer elapse so the pending batch is sent.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    // (no clock advance needed: the size threshold alone triggers the flush)
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    // (dropping the responder makes the fake server's `res_rx.await?` fail,
    // so the store sees an errored request)
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1240
1241// Skipped until we start including diagnostics in prompt
1242// #[gpui::test]
1243// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1244//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1245//     let fs = FakeFs::new(cx.executor());
1246//     fs.insert_tree(
1247//         "/root",
1248//         json!({
1249//             "foo.md": "Hello!\nBye"
1250//         }),
1251//     )
1252//     .await;
1253//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1254
1255//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1256//     let diagnostic = lsp::Diagnostic {
1257//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1258//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1259//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1260//         ..Default::default()
1261//     };
1262
1263//     project.update(cx, |project, cx| {
1264//         project.lsp_store().update(cx, |lsp_store, cx| {
1265//             // Create some diagnostics
1266//             lsp_store
1267//                 .update_diagnostics(
1268//                     LanguageServerId(0),
1269//                     lsp::PublishDiagnosticsParams {
1270//                         uri: path_to_buffer_uri.clone(),
1271//                         diagnostics: vec![diagnostic],
1272//                         version: None,
1273//                     },
1274//                     None,
1275//                     language::DiagnosticSourceKind::Pushed,
1276//                     &[],
1277//                     cx,
1278//                 )
1279//                 .unwrap();
1280//         });
1281//     });
1282
1283//     let buffer = project
1284//         .update(cx, |project, cx| {
1285//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1286//             project.open_buffer(path, cx)
1287//         })
1288//         .await
1289//         .unwrap();
1290
1291//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1292//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1293
1294//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1295//         ep_store.request_prediction(&project, &buffer, position, cx)
1296//     });
1297
1298//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1299
1300//     assert_eq!(request.diagnostic_groups.len(), 1);
1301//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1302//         .unwrap();
1303//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1304//     assert_eq!(
1305//         value,
1306//         json!({
1307//             "entries": [{
1308//                 "range": {
1309//                     "start": 8,
1310//                     "end": 10
1311//                 },
1312//                 "diagnostic": {
1313//                     "source": null,
1314//                     "code": null,
1315//                     "code_description": null,
1316//                     "severity": 1,
1317//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1318//                     "markdown": null,
1319//                     "group_id": 0,
1320//                     "is_primary": true,
1321//                     "is_disk_based": false,
1322//                     "is_unnecessary": false,
1323//                     "source_kind": "Pushed",
1324//                     "data": null,
1325//                     "underline": true
1326//                 }
1327//             }],
1328//             "primary_ix": 0
1329//         })
1330//     );
1331// }
1332
1333// Generate a model response that would apply the given diff to the active file.
1334fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1335    let excerpt =
1336        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1337    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1338
1339    PredictEditsV3Response {
1340        request_id: Uuid::new_v4().to_string(),
1341        output: new_excerpt,
1342    }
1343}
1344
/// Render the prompt text that would be sent to the model for `request`,
/// using the default zeta prompt format.
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
    zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
}
1348
/// Receiving ends of the channels through which the fake HTTP client hands
/// intercepted requests to the test. Each item pairs the decoded request body
/// with a oneshot sender the test uses to supply (or, by dropping it, fail)
/// the response.
struct RequestChannels {
    // `/predict_edits/v3` requests awaiting a `PredictEditsV3Response`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // `/predict_edits/reject` requests awaiting acknowledgement.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1356
/// Set up an `EditPredictionStore` backed by a fake HTTP client, returning the
/// store together with channels that surface every intercepted predict/reject
/// request so the test can inspect it and choose the response.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token endpoint: always hand back a dummy LLM token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction endpoint: bodies are zstd-compressed JSON.
                        // Forward the decoded request to the test and block on
                        // its response; if the test drops the sender, `?`
                        // propagates an error, simulating a failed request.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let decompressed = zstd::decode_all(&buf[..]).unwrap();
                            let req = serde_json::from_slice(&decompressed).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Reject endpoint: plain JSON body, unit response.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1426
/// Permissive (0BSD) license text; inserting it as `LICENSE` marks a test
/// worktree as open source for the data-collection tests below.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1428
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    // Exercises `EditPrediction::interpolate`: as the user edits the buffer,
    // the remaining predicted edits are recomputed against the new snapshot.
    // Typing part of a predicted insertion shrinks that edit; unrelated or
    // conflicting edits eventually make interpolation return `None`.
    //
    // Buffer: "Lorem ipsum dolor". Prediction: replace [2..5) ("rem") with
    // "REM" and delete [9..11).
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Minimal prediction; only `edits`/`buffer`/`snapshot` matter here.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the to-be-replaced text turns the replacement into a pure
        // insertion, and shifts the later deletion left by 3.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing the first predicted character ("R") consumes it: only "EM"
        // remains to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E" consumes the next character.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the final "M" fully consumes the first edit; only the
        // deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the just-typed "M" re-introduces it as a pending insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion by hand removes that edit too.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // Deleting through the remaining insertion point invalidates the
        // prediction entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1544
1545#[gpui::test]
1546async fn test_clean_up_diff(cx: &mut TestAppContext) {
1547    init_test(cx);
1548
1549    assert_eq!(
1550        apply_edit_prediction(
1551            indoc! {"
1552                    fn main() {
1553                        let word_1 = \"lorem\";
1554                        let range = word.len()..word.len();
1555                    }
1556                "},
1557            indoc! {"
1558                    <|editable_region_start|>
1559                    fn main() {
1560                        let word_1 = \"lorem\";
1561                        let range = word_1.len()..word_1.len();
1562                    }
1563
1564                    <|editable_region_end|>
1565                "},
1566            cx,
1567        )
1568        .await,
1569        indoc! {"
1570                fn main() {
1571                    let word_1 = \"lorem\";
1572                    let range = word_1.len()..word_1.len();
1573                }
1574            "},
1575    );
1576
1577    assert_eq!(
1578        apply_edit_prediction(
1579            indoc! {"
1580                    fn main() {
1581                        let story = \"the quick\"
1582                    }
1583                "},
1584            indoc! {"
1585                    <|editable_region_start|>
1586                    fn main() {
1587                        let story = \"the quick brown fox jumps over the lazy dog\";
1588                    }
1589
1590                    <|editable_region_end|>
1591                "},
1592            cx,
1593        )
1594        .await,
1595        indoc! {"
1596                fn main() {
1597                    let story = \"the quick brown fox jumps over the lazy dog\";
1598                }
1599            "},
1600    );
1601}
1602
1603#[gpui::test]
1604async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1605    init_test(cx);
1606
1607    let buffer_content = "lorem\n";
1608    let completion_response = indoc! {"
1609            ```animals.js
1610            <|start_of_file|>
1611            <|editable_region_start|>
1612            lorem
1613            ipsum
1614            <|editable_region_end|>
1615            ```"};
1616
1617    assert_eq!(
1618        apply_edit_prediction(buffer_content, completion_response, cx).await,
1619        "lorem\nipsum"
1620    );
1621}
1622
1623#[gpui::test]
1624async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1625    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1626    // When the buffer ends without a trailing newline, but the model returns output
1627    // with a trailing newline, zeta2 should normalize both sides before diffing
1628    // so no spurious newline is inserted.
1629    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1630    let fs = FakeFs::new(cx.executor());
1631
1632    // Single line buffer with no trailing newline
1633    fs.insert_tree(
1634        "/root",
1635        json!({
1636            "foo.txt": "hello"
1637        }),
1638    )
1639    .await;
1640    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1641
1642    let buffer = project
1643        .update(cx, |project, cx| {
1644            let path = project
1645                .find_project_path(path!("root/foo.txt"), cx)
1646                .unwrap();
1647            project.open_buffer(path, cx)
1648        })
1649        .await
1650        .unwrap();
1651
1652    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1653    let position = snapshot.anchor_before(language::Point::new(0, 5));
1654
1655    ep_store.update(cx, |ep_store, cx| {
1656        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1657    });
1658
1659    let (_request, respond_tx) = requests.predict.next().await.unwrap();
1660
1661    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1662    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1663    let response = PredictEditsV3Response {
1664        request_id: Uuid::new_v4().to_string(),
1665        output: "hello world\n".to_string(),
1666    };
1667    respond_tx.send(response).unwrap();
1668
1669    cx.run_until_parked();
1670
1671    // The prediction should insert " world" without adding a newline
1672    ep_store.update(cx, |ep_store, cx| {
1673        let prediction = ep_store
1674            .prediction_at(&buffer, None, &project, cx)
1675            .expect("should have prediction");
1676        let edits: Vec<_> = prediction
1677            .edits
1678            .iter()
1679            .map(|(range, text)| {
1680                let snapshot = buffer.read(cx).snapshot();
1681                (range.to_offset(&snapshot), text.clone())
1682            })
1683            .collect();
1684        assert_eq!(edits, vec![(5..5, " world".into())]);
1685    });
1686}
1687
1688#[gpui::test]
1689async fn test_can_collect_data(cx: &mut TestAppContext) {
1690    init_test(cx);
1691
1692    let fs = project::FakeFs::new(cx.executor());
1693    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1694        .await;
1695
1696    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1697    let buffer = project
1698        .update(cx, |project, cx| {
1699            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1700        })
1701        .await
1702        .unwrap();
1703
1704    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1705    ep_store.update(cx, |ep_store, _cx| {
1706        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1707    });
1708
1709    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1710    assert_eq!(
1711        captured_request.lock().clone().unwrap().can_collect_data,
1712        true
1713    );
1714
1715    ep_store.update(cx, |ep_store, _cx| {
1716        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1717    });
1718
1719    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1720    assert_eq!(
1721        captured_request.lock().clone().unwrap().can_collect_data,
1722        false
1723    );
1724}
1725
1726#[gpui::test]
1727async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1728    init_test(cx);
1729
1730    let fs = project::FakeFs::new(cx.executor());
1731    let project = Project::test(fs.clone(), [], cx).await;
1732
1733    let buffer = cx.new(|_cx| {
1734        Buffer::remote(
1735            language::BufferId::new(1).unwrap(),
1736            ReplicaId::new(1),
1737            language::Capability::ReadWrite,
1738            "fn main() {\n    println!(\"Hello\");\n}",
1739        )
1740    });
1741
1742    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1743    ep_store.update(cx, |ep_store, _cx| {
1744        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1745    });
1746
1747    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1748    assert_eq!(
1749        captured_request.lock().clone().unwrap().can_collect_data,
1750        false
1751    );
1752}
1753
1754#[gpui::test]
1755async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1756    init_test(cx);
1757
1758    let fs = project::FakeFs::new(cx.executor());
1759    fs.insert_tree(
1760        path!("/project"),
1761        json!({
1762            "LICENSE": BSD_0_TXT,
1763            ".env": "SECRET_KEY=secret"
1764        }),
1765    )
1766    .await;
1767
1768    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1769    let buffer = project
1770        .update(cx, |project, cx| {
1771            project.open_local_buffer("/project/.env", cx)
1772        })
1773        .await
1774        .unwrap();
1775
1776    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1777    ep_store.update(cx, |ep_store, _cx| {
1778        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1779    });
1780
1781    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1782    assert_eq!(
1783        captured_request.lock().clone().unwrap().can_collect_data,
1784        false
1785    );
1786}
1787
1788#[gpui::test]
1789async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1790    init_test(cx);
1791
1792    let fs = project::FakeFs::new(cx.executor());
1793    let project = Project::test(fs.clone(), [], cx).await;
1794    let buffer = cx.new(|cx| Buffer::local("", cx));
1795
1796    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1797    ep_store.update(cx, |ep_store, _cx| {
1798        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1799    });
1800
1801    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1802    assert_eq!(
1803        captured_request.lock().clone().unwrap().can_collect_data,
1804        false
1805    );
1806}
1807
1808#[gpui::test]
1809async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1810    init_test(cx);
1811
1812    let fs = project::FakeFs::new(cx.executor());
1813    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1814        .await;
1815
1816    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1817    let buffer = project
1818        .update(cx, |project, cx| {
1819            project.open_local_buffer("/project/main.rs", cx)
1820        })
1821        .await
1822        .unwrap();
1823
1824    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1825    ep_store.update(cx, |ep_store, _cx| {
1826        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1827    });
1828
1829    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1830    assert_eq!(
1831        captured_request.lock().clone().unwrap().can_collect_data,
1832        false
1833    );
1834}
1835
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    // Two worktrees: one with a permissive (0BSD) LICENSE file, one with no
    // license at all.
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While the buffer's file lives in the open-source worktree, collection
    // is allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Simulate the file being moved into the closed-source worktree by
    // swapping the buffer's backing `File` for one loaded from that worktree.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    // After the "move", the very same buffer must no longer be collectable.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1899
1900#[gpui::test]
1901async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1902    init_test(cx);
1903
1904    let fs = project::FakeFs::new(cx.executor());
1905    fs.insert_tree(
1906        path!("/worktree1"),
1907        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1908    )
1909    .await;
1910    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1911        .await;
1912
1913    let project = Project::test(
1914        fs.clone(),
1915        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1916        cx,
1917    )
1918    .await;
1919    let buffer = project
1920        .update(cx, |project, cx| {
1921            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1922        })
1923        .await
1924        .unwrap();
1925    let private_buffer = project
1926        .update(cx, |project, cx| {
1927            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1928        })
1929        .await
1930        .unwrap();
1931
1932    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1933    ep_store.update(cx, |ep_store, _cx| {
1934        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1935    });
1936
1937    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1938    assert_eq!(
1939        captured_request.lock().clone().unwrap().can_collect_data,
1940        true
1941    );
1942
1943    // this has a side effect of registering the buffer to watch for edits
1944    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1945    assert_eq!(
1946        captured_request.lock().clone().unwrap().can_collect_data,
1947        false
1948    );
1949
1950    private_buffer.update(cx, |private_buffer, cx| {
1951        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1952    });
1953
1954    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1955    assert_eq!(
1956        captured_request.lock().clone().unwrap().can_collect_data,
1957        false
1958    );
1959
1960    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1961    // included
1962    buffer.update(cx, |buffer, cx| {
1963        buffer.edit(
1964            [(
1965                0..0,
1966                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1967            )],
1968            None,
1969            cx,
1970        );
1971    });
1972
1973    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1974    assert_eq!(
1975        captured_request.lock().clone().unwrap().can_collect_data,
1976        true
1977    );
1978}
1979
1980fn init_test(cx: &mut TestAppContext) {
1981    cx.update(|cx| {
1982        let settings_store = SettingsStore::test(cx);
1983        cx.set_global(settings_store);
1984    });
1985}
1986
1987async fn apply_edit_prediction(
1988    buffer_content: &str,
1989    completion_response: &str,
1990    cx: &mut TestAppContext,
1991) -> String {
1992    let fs = project::FakeFs::new(cx.executor());
1993    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1994    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1995    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1996    *response.lock() = completion_response.to_string();
1997    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1998    buffer.update(cx, |buffer, cx| {
1999        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2000    });
2001    buffer.read_with(cx, |buffer, _| buffer.text())
2002}
2003
2004async fn run_edit_prediction(
2005    buffer: &Entity<Buffer>,
2006    project: &Entity<Project>,
2007    ep_store: &Entity<EditPredictionStore>,
2008    cx: &mut TestAppContext,
2009) -> EditPrediction {
2010    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2011    ep_store.update(cx, |ep_store, cx| {
2012        ep_store.register_buffer(buffer, &project, cx)
2013    });
2014    cx.background_executor.run_until_parked();
2015    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2016        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2017    });
2018    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2019}
2020
/// Builds an `EditPredictionStore` backed by a fake HTTP client and fake
/// collab server for edit-prediction tests.
///
/// Returns the store plus two shared handles:
/// - the most recently captured `/predict_edits/v2` request body, for
///   asserting on fields such as `can_collect_data`;
/// - the completion text the fake server responds with (overwrite it to
///   control the predicted output).
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    // Default Zeta1-style response: a code excerpt wrapped in
    // editable-region markers.
    let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        // Captured by the handler closure, so it increments across requests.
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token endpoint: always hand out a fixed LLM token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // v2 prediction endpoint: record the request body for the
                    // test to inspect, then reply with the configured text.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    // v3 prediction endpoint: request body is ignored; replies
                    // with a fixed output string.
                    (&Method::POST, "/predict_edits/v3") => {
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsV3Response {
                                    request_id: format!("request-{next_request_id}"),
                                    output: "hello world".to_string(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Seed a license-detection watcher per worktree so data-collection
        // eligibility checks work without waiting for lazy initialization.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2125
2126fn to_completion_edits(
2127    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2128    buffer: &Entity<Buffer>,
2129    cx: &App,
2130) -> Vec<(Range<Anchor>, Arc<str>)> {
2131    let buffer = buffer.read(cx);
2132    iterator
2133        .into_iter()
2134        .map(|(range, text)| {
2135            (
2136                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2137                text,
2138            )
2139        })
2140        .collect()
2141}
2142
2143fn from_completion_edits(
2144    editor_edits: &[(Range<Anchor>, Arc<str>)],
2145    buffer: &Entity<Buffer>,
2146    cx: &App,
2147) -> Vec<(Range<usize>, Arc<str>)> {
2148    let buffer = buffer.read(cx);
2149    editor_edits
2150        .iter()
2151        .map(|(range, text)| {
2152            (
2153                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2154                text.clone(),
2155            )
2156        })
2157        .collect()
2158}
2159
2160#[gpui::test]
2161async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2162    init_test(cx);
2163
2164    let fs = FakeFs::new(cx.executor());
2165    fs.insert_tree(
2166        "/project",
2167        serde_json::json!({
2168            "main.rs": "fn main() {\n    \n}\n"
2169        }),
2170    )
2171    .await;
2172
2173    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2174
2175    let http_client = FakeHttpClient::create(|_req| async move {
2176        Ok(gpui::http_client::Response::builder()
2177            .status(401)
2178            .body("Unauthorized".into())
2179            .unwrap())
2180    });
2181
2182    let client =
2183        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2184    cx.update(|cx| {
2185        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2186    });
2187
2188    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2189
2190    let buffer = project
2191        .update(cx, |project, cx| {
2192            let path = project
2193                .find_project_path(path!("/project/main.rs"), cx)
2194                .unwrap();
2195            project.open_buffer(path, cx)
2196        })
2197        .await
2198        .unwrap();
2199
2200    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2201    ep_store.update(cx, |ep_store, cx| {
2202        ep_store.register_buffer(&buffer, &project, cx)
2203    });
2204    cx.background_executor.run_until_parked();
2205
2206    let completion_task = ep_store.update(cx, |ep_store, cx| {
2207        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2208        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2209    });
2210
2211    let result = completion_task.await;
2212    assert!(
2213        result.is_err(),
2214        "Without authentication and without custom URL, prediction should fail"
2215    );
2216}
2217
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    // A 25-line buffer: the two insertions below are close enough together
    // that their context lines merge into a single diff hunk.
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Apply the insertions out of order (row 12 first, then row 8) to check
    // the diff is computed from buffer contents, not edit order.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Expect one unified-diff hunk containing both insertions in document
    // order, with three lines of context on each side.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
            "}
    );
}
2285
// Runs once at test-binary startup (via `ctor`) to set up test logging.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}