edit_prediction_tests.rs

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  10};
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::App;
  16use gpui::{
  17    Entity, TestAppContext,
  18    http_client::{FakeHttpClient, Response},
  19};
  20use indoc::indoc;
  21use language::{Buffer, Point};
  22use lsp::LanguageServerId;
  23use parking_lot::Mutex;
  24use pretty_assertions::{assert_eq, assert_matches};
  25use project::{FakeFs, Project};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, sync::Arc, time::Duration};
  29use util::{path, rel_path::rel_path};
  30use uuid::Uuid;
  31use zeta_prompt::ZetaPromptInput;
  32
  33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  34
  35#[gpui::test]
  36async fn test_current_state(cx: &mut TestAppContext) {
  37    let (ep_store, mut requests) = init_test_with_fake_client(cx);
  38    let fs = FakeFs::new(cx.executor());
  39    fs.insert_tree(
  40        "/root",
  41        json!({
  42            "1.txt": "Hello!\nHow\nBye\n",
  43            "2.txt": "Hola!\nComo\nAdios\n"
  44        }),
  45    )
  46    .await;
  47    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
  48
  49    let buffer1 = project
  50        .update(cx, |project, cx| {
  51            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
  52            project.set_active_path(Some(path.clone()), cx);
  53            project.open_buffer(path, cx)
  54        })
  55        .await
  56        .unwrap();
  57    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
  58    let position = snapshot1.anchor_before(language::Point::new(1, 3));
  59
  60    ep_store.update(cx, |ep_store, cx| {
  61        ep_store.register_project(&project, cx);
  62        ep_store.register_buffer(&buffer1, &project, cx);
  63    });
  64
  65    // Prediction for current file
  66
  67    ep_store.update(cx, |ep_store, cx| {
  68        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
  69    });
  70    let (request, respond_tx) = requests.predict.next().await.unwrap();
  71
  72    respond_tx
  73        .send(model_response(
  74            &request,
  75            indoc! {r"
  76                --- a/root/1.txt
  77                +++ b/root/1.txt
  78                @@ ... @@
  79                 Hello!
  80                -How
  81                +How are you?
  82                 Bye
  83            "},
  84        ))
  85        .unwrap();
  86
  87    cx.run_until_parked();
  88
  89    ep_store.update(cx, |ep_store, cx| {
  90        let prediction = ep_store
  91            .prediction_at(&buffer1, None, &project, cx)
  92            .unwrap();
  93        assert_matches!(prediction, BufferEditPrediction::Local { .. });
  94    });
  95
  96    ep_store.update(cx, |ep_store, cx| {
  97        ep_store.reject_current_prediction(
  98            EditPredictionRejectReason::Discarded,
  99            &project,
 100            edit_prediction_types::EditPredictionDismissReason::Ignored,
 101            cx,
 102        );
 103    });
 104
 105    // Prediction for diagnostic in another file
 106
 107    let diagnostic = lsp::Diagnostic {
 108        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
 109        severity: Some(lsp::DiagnosticSeverity::ERROR),
 110        message: "Sentence is incomplete".to_string(),
 111        ..Default::default()
 112    };
 113
 114    project.update(cx, |project, cx| {
 115        project.lsp_store().update(cx, |lsp_store, cx| {
 116            lsp_store
 117                .update_diagnostics(
 118                    LanguageServerId(0),
 119                    lsp::PublishDiagnosticsParams {
 120                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
 121                        diagnostics: vec![diagnostic],
 122                        version: None,
 123                    },
 124                    None,
 125                    language::DiagnosticSourceKind::Pushed,
 126                    &[],
 127                    cx,
 128                )
 129                .unwrap();
 130        });
 131    });
 132
 133    let (request, respond_tx) = requests.predict.next().await.unwrap();
 134    respond_tx
 135        .send(model_response(
 136            &request,
 137            indoc! {r#"
 138                --- a/root/2.txt
 139                +++ b/root/2.txt
 140                @@ ... @@
 141                 Hola!
 142                -Como
 143                +Como estas?
 144                 Adios
 145            "#},
 146        ))
 147        .unwrap();
 148    cx.run_until_parked();
 149
 150    ep_store.update(cx, |ep_store, cx| {
 151        let prediction = ep_store
 152            .prediction_at(&buffer1, None, &project, cx)
 153            .unwrap();
 154        assert_matches!(
 155            prediction,
 156            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
 157        );
 158    });
 159
 160    let buffer2 = project
 161        .update(cx, |project, cx| {
 162            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
 163            project.open_buffer(path, cx)
 164        })
 165        .await
 166        .unwrap();
 167
 168    ep_store.update(cx, |ep_store, cx| {
 169        let prediction = ep_store
 170            .prediction_at(&buffer2, None, &project, cx)
 171            .unwrap();
 172        assert_matches!(prediction, BufferEditPrediction::Local { .. });
 173    });
 174}
 175
 176#[gpui::test]
 177async fn test_simple_request(cx: &mut TestAppContext) {
 178    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 179    let fs = FakeFs::new(cx.executor());
 180    fs.insert_tree(
 181        "/root",
 182        json!({
 183            "foo.md":  "Hello!\nHow\nBye\n"
 184        }),
 185    )
 186    .await;
 187    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 188
 189    let buffer = project
 190        .update(cx, |project, cx| {
 191            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 192            project.open_buffer(path, cx)
 193        })
 194        .await
 195        .unwrap();
 196    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 197    let position = snapshot.anchor_before(language::Point::new(1, 3));
 198
 199    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 200        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 201    });
 202
 203    let (request, respond_tx) = requests.predict.next().await.unwrap();
 204
 205    // TODO Put back when we have a structured request again
 206    // assert_eq!(
 207    //     request.excerpt_path.as_ref(),
 208    //     Path::new(path!("root/foo.md"))
 209    // );
 210    // assert_eq!(
 211    //     request.cursor_point,
 212    //     Point {
 213    //         line: Line(1),
 214    //         column: 3
 215    //     }
 216    // );
 217
 218    respond_tx
 219        .send(model_response(
 220            &request,
 221            indoc! { r"
 222                --- a/root/foo.md
 223                +++ b/root/foo.md
 224                @@ ... @@
 225                 Hello!
 226                -How
 227                +How are you?
 228                 Bye
 229            "},
 230        ))
 231        .unwrap();
 232
 233    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 234
 235    assert_eq!(prediction.edits.len(), 1);
 236    assert_eq!(
 237        prediction.edits[0].0.to_point(&snapshot).start,
 238        language::Point::new(1, 3)
 239    );
 240    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 241}
 242
 243#[gpui::test]
 244async fn test_request_events(cx: &mut TestAppContext) {
 245    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 246    let fs = FakeFs::new(cx.executor());
 247    fs.insert_tree(
 248        "/root",
 249        json!({
 250            "foo.md": "Hello!\n\nBye\n"
 251        }),
 252    )
 253    .await;
 254    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 255
 256    let buffer = project
 257        .update(cx, |project, cx| {
 258            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 259            project.open_buffer(path, cx)
 260        })
 261        .await
 262        .unwrap();
 263
 264    ep_store.update(cx, |ep_store, cx| {
 265        ep_store.register_buffer(&buffer, &project, cx);
 266    });
 267
 268    buffer.update(cx, |buffer, cx| {
 269        buffer.edit(vec![(7..7, "How")], None, cx);
 270    });
 271
 272    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 273    let position = snapshot.anchor_before(language::Point::new(1, 3));
 274
 275    let prediction_task = ep_store.update(cx, |ep_store, cx| {
 276        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
 277    });
 278
 279    let (request, respond_tx) = requests.predict.next().await.unwrap();
 280
 281    let prompt = prompt_from_request(&request);
 282    assert!(
 283        prompt.contains(indoc! {"
 284        --- a/root/foo.md
 285        +++ b/root/foo.md
 286        @@ -1,3 +1,3 @@
 287         Hello!
 288        -
 289        +How
 290         Bye
 291    "}),
 292        "{prompt}"
 293    );
 294
 295    respond_tx
 296        .send(model_response(
 297            &request,
 298            indoc! {r#"
 299                --- a/root/foo.md
 300                +++ b/root/foo.md
 301                @@ ... @@
 302                 Hello!
 303                -How
 304                +How are you?
 305                 Bye
 306        "#},
 307        ))
 308        .unwrap();
 309
 310    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
 311
 312    assert_eq!(prediction.edits.len(), 1);
 313    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
 314}
 315
 316#[gpui::test]
 317async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
 318    let (ep_store, _requests) = init_test_with_fake_client(cx);
 319    let fs = FakeFs::new(cx.executor());
 320    fs.insert_tree(
 321        "/root",
 322        json!({
 323            "foo.md": "Hello!\n\nBye\n"
 324        }),
 325    )
 326    .await;
 327    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 328
 329    let buffer = project
 330        .update(cx, |project, cx| {
 331            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 332            project.open_buffer(path, cx)
 333        })
 334        .await
 335        .unwrap();
 336
 337    ep_store.update(cx, |ep_store, cx| {
 338        ep_store.register_buffer(&buffer, &project, cx);
 339    });
 340
 341    // First burst: insert "How"
 342    buffer.update(cx, |buffer, cx| {
 343        buffer.edit(vec![(7..7, "How")], None, cx);
 344    });
 345
 346    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
 347    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
 348    cx.run_until_parked();
 349
 350    // Second burst: append " are you?" immediately after "How" on the same line.
 351    //
 352    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
 353    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
 354    buffer.update(cx, |buffer, cx| {
 355        buffer.edit(vec![(10..10, " are you?")], None, cx);
 356    });
 357
 358    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
 359    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
 360    buffer.update(cx, |buffer, cx| {
 361        buffer.edit(vec![(19..19, "!")], None, cx);
 362    });
 363
 364    // Without time-based splitting, there is one event.
 365    let events = ep_store.update(cx, |ep_store, cx| {
 366        ep_store.edit_history_for_project(&project, cx)
 367    });
 368    assert_eq!(events.len(), 1);
 369    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 370    assert_eq!(
 371        diff.as_str(),
 372        indoc! {"
 373            @@ -1,3 +1,3 @@
 374             Hello!
 375            -
 376            +How are you?!
 377             Bye
 378        "}
 379    );
 380
 381    // With time-based splitting, there are two distinct events.
 382    let events = ep_store.update(cx, |ep_store, cx| {
 383        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
 384    });
 385    assert_eq!(events.len(), 2);
 386    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
 387    assert_eq!(
 388        diff.as_str(),
 389        indoc! {"
 390            @@ -1,3 +1,3 @@
 391             Hello!
 392            -
 393            +How
 394             Bye
 395        "}
 396    );
 397
 398    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
 399    assert_eq!(
 400        diff.as_str(),
 401        indoc! {"
 402            @@ -1,3 +1,3 @@
 403             Hello!
 404            -How
 405            +How are you?!
 406             Bye
 407        "}
 408    );
 409}
 410
 411#[gpui::test]
 412async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
 413    let (ep_store, _requests) = init_test_with_fake_client(cx);
 414    let fs = FakeFs::new(cx.executor());
 415
 416    // Create a file with 30 lines to test line-based coalescing
 417    let content = (1..=30)
 418        .map(|i| format!("Line {}\n", i))
 419        .collect::<String>();
 420    fs.insert_tree(
 421        "/root",
 422        json!({
 423            "foo.md": content
 424        }),
 425    )
 426    .await;
 427    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 428
 429    let buffer = project
 430        .update(cx, |project, cx| {
 431            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 432            project.open_buffer(path, cx)
 433        })
 434        .await
 435        .unwrap();
 436
 437    ep_store.update(cx, |ep_store, cx| {
 438        ep_store.register_buffer(&buffer, &project, cx);
 439    });
 440
 441    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
 442    buffer.update(cx, |buffer, cx| {
 443        let start = Point::new(10, 0).to_offset(buffer);
 444        let end = Point::new(13, 0).to_offset(buffer);
 445        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
 446    });
 447
 448    let events = ep_store.update(cx, |ep_store, cx| {
 449        ep_store.edit_history_for_project(&project, cx)
 450    });
 451    assert_eq!(
 452        render_events(&events),
 453        indoc! {"
 454            @@ -8,9 +8,8 @@
 455             Line 8
 456             Line 9
 457             Line 10
 458            -Line 11
 459            -Line 12
 460            -Line 13
 461            +Middle A
 462            +Middle B
 463             Line 14
 464             Line 15
 465             Line 16
 466        "},
 467        "After first edit"
 468    );
 469
 470    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
 471    // This tests that coalescing considers the START of the existing range
 472    buffer.update(cx, |buffer, cx| {
 473        let offset = Point::new(5, 0).to_offset(buffer);
 474        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
 475    });
 476
 477    let events = ep_store.update(cx, |ep_store, cx| {
 478        ep_store.edit_history_for_project(&project, cx)
 479    });
 480    assert_eq!(
 481        render_events(&events),
 482        indoc! {"
 483            @@ -3,14 +3,14 @@
 484             Line 3
 485             Line 4
 486             Line 5
 487            +Above
 488             Line 6
 489             Line 7
 490             Line 8
 491             Line 9
 492             Line 10
 493            -Line 11
 494            -Line 12
 495            -Line 13
 496            +Middle A
 497            +Middle B
 498             Line 14
 499             Line 15
 500             Line 16
 501        "},
 502        "After inserting above (should coalesce)"
 503    );
 504
 505    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
 506    // This tests that coalescing considers the END of the existing range
 507    buffer.update(cx, |buffer, cx| {
 508        let offset = Point::new(14, 0).to_offset(buffer);
 509        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
 510    });
 511
 512    let events = ep_store.update(cx, |ep_store, cx| {
 513        ep_store.edit_history_for_project(&project, cx)
 514    });
 515    assert_eq!(
 516        render_events(&events),
 517        indoc! {"
 518            @@ -3,15 +3,16 @@
 519             Line 3
 520             Line 4
 521             Line 5
 522            +Above
 523             Line 6
 524             Line 7
 525             Line 8
 526             Line 9
 527             Line 10
 528            -Line 11
 529            -Line 12
 530            -Line 13
 531            +Middle A
 532            +Middle B
 533             Line 14
 534            +Below
 535             Line 15
 536             Line 16
 537             Line 17
 538        "},
 539        "After inserting below (should coalesce)"
 540    );
 541
 542    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
 543    // This should NOT coalesce - creates a new event
 544    buffer.update(cx, |buffer, cx| {
 545        let offset = Point::new(25, 0).to_offset(buffer);
 546        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
 547    });
 548
 549    let events = ep_store.update(cx, |ep_store, cx| {
 550        ep_store.edit_history_for_project(&project, cx)
 551    });
 552    assert_eq!(
 553        render_events(&events),
 554        indoc! {"
 555            @@ -3,15 +3,16 @@
 556             Line 3
 557             Line 4
 558             Line 5
 559            +Above
 560             Line 6
 561             Line 7
 562             Line 8
 563             Line 9
 564             Line 10
 565            -Line 11
 566            -Line 12
 567            -Line 13
 568            +Middle A
 569            +Middle B
 570             Line 14
 571            +Below
 572             Line 15
 573             Line 16
 574             Line 17
 575
 576            ---
 577            @@ -23,6 +23,7 @@
 578             Line 22
 579             Line 23
 580             Line 24
 581            +Far below
 582             Line 25
 583             Line 26
 584             Line 27
 585        "},
 586        "After inserting far below (should NOT coalesce)"
 587    );
 588}
 589
 590fn render_events(events: &[StoredEvent]) -> String {
 591    events
 592        .iter()
 593        .map(|e| {
 594            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 595            diff.as_str()
 596        })
 597        .collect::<Vec<_>>()
 598        .join("\n---\n")
 599}
 600
 601#[gpui::test]
 602async fn test_empty_prediction(cx: &mut TestAppContext) {
 603    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 604    let fs = FakeFs::new(cx.executor());
 605    fs.insert_tree(
 606        "/root",
 607        json!({
 608            "foo.md":  "Hello!\nHow\nBye\n"
 609        }),
 610    )
 611    .await;
 612    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 613
 614    let buffer = project
 615        .update(cx, |project, cx| {
 616            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 617            project.open_buffer(path, cx)
 618        })
 619        .await
 620        .unwrap();
 621    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 622    let position = snapshot.anchor_before(language::Point::new(1, 3));
 623
 624    ep_store.update(cx, |ep_store, cx| {
 625        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 626    });
 627
 628    let (request, respond_tx) = requests.predict.next().await.unwrap();
 629    let response = model_response(&request, "");
 630    let id = response.request_id.clone();
 631    respond_tx.send(response).unwrap();
 632
 633    cx.run_until_parked();
 634
 635    ep_store.update(cx, |ep_store, cx| {
 636        assert!(
 637            ep_store
 638                .prediction_at(&buffer, None, &project, cx)
 639                .is_none()
 640        );
 641    });
 642
 643    // prediction is reported as rejected
 644    let (reject_request, _) = requests.reject.next().await.unwrap();
 645
 646    assert_eq!(
 647        &reject_request.rejections,
 648        &[EditPredictionRejection {
 649            request_id: id,
 650            reason: EditPredictionRejectReason::Empty,
 651            was_shown: false
 652        }]
 653    );
 654}
 655
 656#[gpui::test]
 657async fn test_interpolated_empty(cx: &mut TestAppContext) {
 658    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 659    let fs = FakeFs::new(cx.executor());
 660    fs.insert_tree(
 661        "/root",
 662        json!({
 663            "foo.md":  "Hello!\nHow\nBye\n"
 664        }),
 665    )
 666    .await;
 667    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 668
 669    let buffer = project
 670        .update(cx, |project, cx| {
 671            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 672            project.open_buffer(path, cx)
 673        })
 674        .await
 675        .unwrap();
 676    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 677    let position = snapshot.anchor_before(language::Point::new(1, 3));
 678
 679    ep_store.update(cx, |ep_store, cx| {
 680        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 681    });
 682
 683    let (request, respond_tx) = requests.predict.next().await.unwrap();
 684
 685    buffer.update(cx, |buffer, cx| {
 686        buffer.set_text("Hello!\nHow are you?\nBye", cx);
 687    });
 688
 689    let response = model_response(&request, SIMPLE_DIFF);
 690    let id = response.request_id.clone();
 691    respond_tx.send(response).unwrap();
 692
 693    cx.run_until_parked();
 694
 695    ep_store.update(cx, |ep_store, cx| {
 696        assert!(
 697            ep_store
 698                .prediction_at(&buffer, None, &project, cx)
 699                .is_none()
 700        );
 701    });
 702
 703    // prediction is reported as rejected
 704    let (reject_request, _) = requests.reject.next().await.unwrap();
 705
 706    assert_eq!(
 707        &reject_request.rejections,
 708        &[EditPredictionRejection {
 709            request_id: id,
 710            reason: EditPredictionRejectReason::InterpolatedEmpty,
 711            was_shown: false
 712        }]
 713    );
 714}
 715
 716const SIMPLE_DIFF: &str = indoc! { r"
 717    --- a/root/foo.md
 718    +++ b/root/foo.md
 719    @@ ... @@
 720     Hello!
 721    -How
 722    +How are you?
 723     Bye
 724"};
 725
 726#[gpui::test]
 727async fn test_replace_current(cx: &mut TestAppContext) {
 728    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 729    let fs = FakeFs::new(cx.executor());
 730    fs.insert_tree(
 731        "/root",
 732        json!({
 733            "foo.md":  "Hello!\nHow\nBye\n"
 734        }),
 735    )
 736    .await;
 737    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 738
 739    let buffer = project
 740        .update(cx, |project, cx| {
 741            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 742            project.open_buffer(path, cx)
 743        })
 744        .await
 745        .unwrap();
 746    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 747    let position = snapshot.anchor_before(language::Point::new(1, 3));
 748
 749    ep_store.update(cx, |ep_store, cx| {
 750        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 751    });
 752
 753    let (request, respond_tx) = requests.predict.next().await.unwrap();
 754    let first_response = model_response(&request, SIMPLE_DIFF);
 755    let first_id = first_response.request_id.clone();
 756    respond_tx.send(first_response).unwrap();
 757
 758    cx.run_until_parked();
 759
 760    ep_store.update(cx, |ep_store, cx| {
 761        assert_eq!(
 762            ep_store
 763                .prediction_at(&buffer, None, &project, cx)
 764                .unwrap()
 765                .id
 766                .0,
 767            first_id
 768        );
 769    });
 770
 771    // a second request is triggered
 772    ep_store.update(cx, |ep_store, cx| {
 773        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 774    });
 775
 776    let (request, respond_tx) = requests.predict.next().await.unwrap();
 777    let second_response = model_response(&request, SIMPLE_DIFF);
 778    let second_id = second_response.request_id.clone();
 779    respond_tx.send(second_response).unwrap();
 780
 781    cx.run_until_parked();
 782
 783    ep_store.update(cx, |ep_store, cx| {
 784        // second replaces first
 785        assert_eq!(
 786            ep_store
 787                .prediction_at(&buffer, None, &project, cx)
 788                .unwrap()
 789                .id
 790                .0,
 791            second_id
 792        );
 793    });
 794
 795    // first is reported as replaced
 796    let (reject_request, _) = requests.reject.next().await.unwrap();
 797
 798    assert_eq!(
 799        &reject_request.rejections,
 800        &[EditPredictionRejection {
 801            request_id: first_id,
 802            reason: EditPredictionRejectReason::Replaced,
 803            was_shown: false
 804        }]
 805    );
 806}
 807
 808#[gpui::test]
 809async fn test_current_preferred(cx: &mut TestAppContext) {
 810    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 811    let fs = FakeFs::new(cx.executor());
 812    fs.insert_tree(
 813        "/root",
 814        json!({
 815            "foo.md":  "Hello!\nHow\nBye\n"
 816        }),
 817    )
 818    .await;
 819    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 820
 821    let buffer = project
 822        .update(cx, |project, cx| {
 823            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 824            project.open_buffer(path, cx)
 825        })
 826        .await
 827        .unwrap();
 828    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 829    let position = snapshot.anchor_before(language::Point::new(1, 3));
 830
 831    ep_store.update(cx, |ep_store, cx| {
 832        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 833    });
 834
 835    let (request, respond_tx) = requests.predict.next().await.unwrap();
 836    let first_response = model_response(&request, SIMPLE_DIFF);
 837    let first_id = first_response.request_id.clone();
 838    respond_tx.send(first_response).unwrap();
 839
 840    cx.run_until_parked();
 841
 842    ep_store.update(cx, |ep_store, cx| {
 843        assert_eq!(
 844            ep_store
 845                .prediction_at(&buffer, None, &project, cx)
 846                .unwrap()
 847                .id
 848                .0,
 849            first_id
 850        );
 851    });
 852
 853    // a second request is triggered
 854    ep_store.update(cx, |ep_store, cx| {
 855        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 856    });
 857
 858    let (request, respond_tx) = requests.predict.next().await.unwrap();
 859    // worse than current prediction
 860    let second_response = model_response(
 861        &request,
 862        indoc! { r"
 863            --- a/root/foo.md
 864            +++ b/root/foo.md
 865            @@ ... @@
 866             Hello!
 867            -How
 868            +How are
 869             Bye
 870        "},
 871    );
 872    let second_id = second_response.request_id.clone();
 873    respond_tx.send(second_response).unwrap();
 874
 875    cx.run_until_parked();
 876
 877    ep_store.update(cx, |ep_store, cx| {
 878        // first is preferred over second
 879        assert_eq!(
 880            ep_store
 881                .prediction_at(&buffer, None, &project, cx)
 882                .unwrap()
 883                .id
 884                .0,
 885            first_id
 886        );
 887    });
 888
 889    // second is reported as rejected
 890    let (reject_request, _) = requests.reject.next().await.unwrap();
 891
 892    assert_eq!(
 893        &reject_request.rejections,
 894        &[EditPredictionRejection {
 895            request_id: second_id,
 896            reason: EditPredictionRejectReason::CurrentPreferred,
 897            was_shown: false
 898        }]
 899    );
 900}
 901
 902#[gpui::test]
 903async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
 904    let (ep_store, mut requests) = init_test_with_fake_client(cx);
 905    let fs = FakeFs::new(cx.executor());
 906    fs.insert_tree(
 907        "/root",
 908        json!({
 909            "foo.md":  "Hello!\nHow\nBye\n"
 910        }),
 911    )
 912    .await;
 913    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 914
 915    let buffer = project
 916        .update(cx, |project, cx| {
 917            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 918            project.open_buffer(path, cx)
 919        })
 920        .await
 921        .unwrap();
 922    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 923    let position = snapshot.anchor_before(language::Point::new(1, 3));
 924
 925    // start two refresh tasks
 926    ep_store.update(cx, |ep_store, cx| {
 927        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 928    });
 929
 930    let (request1, respond_first) = requests.predict.next().await.unwrap();
 931
 932    ep_store.update(cx, |ep_store, cx| {
 933        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
 934    });
 935
 936    let (request, respond_second) = requests.predict.next().await.unwrap();
 937
 938    // wait for throttle
 939    cx.run_until_parked();
 940
 941    // second responds first
 942    let second_response = model_response(&request, SIMPLE_DIFF);
 943    let second_id = second_response.request_id.clone();
 944    respond_second.send(second_response).unwrap();
 945
 946    cx.run_until_parked();
 947
 948    ep_store.update(cx, |ep_store, cx| {
 949        // current prediction is second
 950        assert_eq!(
 951            ep_store
 952                .prediction_at(&buffer, None, &project, cx)
 953                .unwrap()
 954                .id
 955                .0,
 956            second_id
 957        );
 958    });
 959
 960    let first_response = model_response(&request1, SIMPLE_DIFF);
 961    let first_id = first_response.request_id.clone();
 962    respond_first.send(first_response).unwrap();
 963
 964    cx.run_until_parked();
 965
 966    ep_store.update(cx, |ep_store, cx| {
 967        // current prediction is still second, since first was cancelled
 968        assert_eq!(
 969            ep_store
 970                .prediction_at(&buffer, None, &project, cx)
 971                .unwrap()
 972                .id
 973                .0,
 974            second_id
 975        );
 976    });
 977
 978    // first is reported as rejected
 979    let (reject_request, _) = requests.reject.next().await.unwrap();
 980
 981    cx.run_until_parked();
 982
 983    assert_eq!(
 984        &reject_request.rejections,
 985        &[EditPredictionRejection {
 986            request_id: first_id,
 987            reason: EditPredictionRejectReason::Canceled,
 988            was_shown: false
 989        }]
 990    );
 991}
 992
#[gpui::test]
/// Starts three refreshes. When the third one is started while the first two
/// are still pending, the second is cancelled. The first prediction is shown
/// until the third completes and replaces it. Both the cancelled second and the
/// replaced first are then reported in a single rejection request.
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled.
        // NOTE(review): `1` presumably identifies the second request (zero-based
        // index/sequence number) — confirm against `cancelled_predictions`.
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // The first request resolves and becomes the current prediction.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request still receives a response, which must be ignored.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected (Canceled), together with the first (Replaced)
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1128
1129#[gpui::test]
1130async fn test_rejections_flushing(cx: &mut TestAppContext) {
1131    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1132
1133    ep_store.update(cx, |ep_store, cx| {
1134        ep_store.reject_prediction(
1135            EditPredictionId("test-1".into()),
1136            EditPredictionRejectReason::Discarded,
1137            false,
1138            edit_prediction_types::EditPredictionDismissReason::Ignored,
1139            cx,
1140        );
1141        ep_store.reject_prediction(
1142            EditPredictionId("test-2".into()),
1143            EditPredictionRejectReason::Canceled,
1144            true,
1145            edit_prediction_types::EditPredictionDismissReason::Ignored,
1146            cx,
1147        );
1148    });
1149
1150    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1151    cx.run_until_parked();
1152
1153    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1154    respond_tx.send(()).unwrap();
1155
1156    // batched
1157    assert_eq!(reject_request.rejections.len(), 2);
1158    assert_eq!(
1159        reject_request.rejections[0],
1160        EditPredictionRejection {
1161            request_id: "test-1".to_string(),
1162            reason: EditPredictionRejectReason::Discarded,
1163            was_shown: false
1164        }
1165    );
1166    assert_eq!(
1167        reject_request.rejections[1],
1168        EditPredictionRejection {
1169            request_id: "test-2".to_string(),
1170            reason: EditPredictionRejectReason::Canceled,
1171            was_shown: true
1172        }
1173    );
1174
1175    // Reaching batch size limit sends without debounce
1176    ep_store.update(cx, |ep_store, cx| {
1177        for i in 0..70 {
1178            ep_store.reject_prediction(
1179                EditPredictionId(format!("batch-{}", i).into()),
1180                EditPredictionRejectReason::Discarded,
1181                false,
1182                edit_prediction_types::EditPredictionDismissReason::Ignored,
1183                cx,
1184            );
1185        }
1186    });
1187
1188    // First MAX/2 items are sent immediately
1189    cx.run_until_parked();
1190    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1191    respond_tx.send(()).unwrap();
1192
1193    assert_eq!(reject_request.rejections.len(), 50);
1194    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
1195    assert_eq!(reject_request.rejections[49].request_id, "batch-49");
1196
1197    // Remaining items are debounced with the next batch
1198    cx.executor().advance_clock(Duration::from_secs(15));
1199    cx.run_until_parked();
1200
1201    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1202    respond_tx.send(()).unwrap();
1203
1204    assert_eq!(reject_request.rejections.len(), 20);
1205    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
1206    assert_eq!(reject_request.rejections[19].request_id, "batch-69");
1207
1208    // Request failure
1209    ep_store.update(cx, |ep_store, cx| {
1210        ep_store.reject_prediction(
1211            EditPredictionId("retry-1".into()),
1212            EditPredictionRejectReason::Discarded,
1213            false,
1214            edit_prediction_types::EditPredictionDismissReason::Ignored,
1215            cx,
1216        );
1217    });
1218
1219    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1220    cx.run_until_parked();
1221
1222    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
1223    assert_eq!(reject_request.rejections.len(), 1);
1224    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1225    // Simulate failure
1226    drop(_respond_tx);
1227
1228    // Add another rejection
1229    ep_store.update(cx, |ep_store, cx| {
1230        ep_store.reject_prediction(
1231            EditPredictionId("retry-2".into()),
1232            EditPredictionRejectReason::Discarded,
1233            false,
1234            edit_prediction_types::EditPredictionDismissReason::Ignored,
1235            cx,
1236        );
1237    });
1238
1239    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
1240    cx.run_until_parked();
1241
1242    // Retry should include both the failed item and the new one
1243    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
1244    respond_tx.send(()).unwrap();
1245
1246    assert_eq!(reject_request.rejections.len(), 2);
1247    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
1248    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
1249}
1250
// Skipped until we start including diagnostics in the prompt.
1252// #[gpui::test]
1253// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1254//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1255//     let fs = FakeFs::new(cx.executor());
1256//     fs.insert_tree(
1257//         "/root",
1258//         json!({
1259//             "foo.md": "Hello!\nBye"
1260//         }),
1261//     )
1262//     .await;
1263//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1264
1265//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1266//     let diagnostic = lsp::Diagnostic {
1267//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1268//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1269//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1270//         ..Default::default()
1271//     };
1272
1273//     project.update(cx, |project, cx| {
1274//         project.lsp_store().update(cx, |lsp_store, cx| {
1275//             // Create some diagnostics
1276//             lsp_store
1277//                 .update_diagnostics(
1278//                     LanguageServerId(0),
1279//                     lsp::PublishDiagnosticsParams {
1280//                         uri: path_to_buffer_uri.clone(),
1281//                         diagnostics: vec![diagnostic],
1282//                         version: None,
1283//                     },
1284//                     None,
1285//                     language::DiagnosticSourceKind::Pushed,
1286//                     &[],
1287//                     cx,
1288//                 )
1289//                 .unwrap();
1290//         });
1291//     });
1292
1293//     let buffer = project
1294//         .update(cx, |project, cx| {
1295//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1296//             project.open_buffer(path, cx)
1297//         })
1298//         .await
1299//         .unwrap();
1300
1301//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1302//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1303
1304//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1305//         ep_store.request_prediction(&project, &buffer, position, cx)
1306//     });
1307
1308//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1309
1310//     assert_eq!(request.diagnostic_groups.len(), 1);
1311//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1312//         .unwrap();
1313//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1314//     assert_eq!(
1315//         value,
1316//         json!({
1317//             "entries": [{
1318//                 "range": {
1319//                     "start": 8,
1320//                     "end": 10
1321//                 },
1322//                 "diagnostic": {
1323//                     "source": null,
1324//                     "code": null,
1325//                     "code_description": null,
1326//                     "severity": 1,
1327//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1328//                     "markdown": null,
1329//                     "group_id": 0,
1330//                     "is_primary": true,
1331//                     "is_disk_based": false,
1332//                     "is_unnecessary": false,
1333//                     "source_kind": "Pushed",
1334//                     "data": null,
1335//                     "underline": true
1336//                 }
1337//             }],
1338//             "primary_ix": 0
1339//         })
1340//     );
1341// }
1342
1343// Generate a model response that would apply the given diff to the active file.
1344fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1345    let excerpt =
1346        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1347    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1348
1349    PredictEditsV3Response {
1350        request_id: Uuid::new_v4().to_string(),
1351        output: new_excerpt,
1352    }
1353}
1354
1355fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1356    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
1357}
1358
/// Receiving ends of the fake HTTP endpoints installed by `init_test_with_fake_client`.
///
/// Each item pairs the deserialized request body with a oneshot sender that the test
/// uses to supply the response. The fake HTTP call waits on that sender. If the sender
/// is dropped without sending, the call fails.
struct RequestChannels {
    /// Requests sent to `/predict_edits/v3`.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    /// Requests sent to `/predict_edits/reject`.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1366
/// Installs test settings and logging, then builds the global `EditPredictionStore`
/// on top of a fake HTTP client.
///
/// The fake client behaves as follows:
/// - `/client/llm_tokens` is answered immediately with `{"token": "test"}`.
/// - `/predict_edits/v3` and `/predict_edits/reject` requests have their JSON body
///   deserialized and forwarded to the returned [`RequestChannels`], together with a
///   oneshot sender. The HTTP call stays pending until the test sends a response; if
///   the sender is dropped instead, the call returns an error.
/// - Any other path panics.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                // Each request gets its own sender clones so the async block can own them.
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/v3" => {
                            // Body read errors are ignored; a bad body then fails in `from_slice`.
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            // The target type is inferred from `predict_req_tx`'s item type.
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // Waits for the test's response; a dropped sender becomes an error.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // Dropping the sender here is how tests simulate a failed upload.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // NOTE(review): presumably fake credentials (user id 1) so authenticated calls proceed — confirm.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1435
/// Contents of `license_examples/0bsd.txt`, used as the `LICENSE` file in the
/// data-collection test fixtures.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1437
#[gpui::test]
/// Checks that `EditPrediction::interpolate` adjusts a prediction's edits as the user
/// edits the buffer. It expects the following:
/// - Typing that matches the prediction shrinks the remaining edits.
/// - Undoing an edit restores the original edits.
/// - Diverging from the prediction returns `None`.
///
/// Initial buffer: "Lorem ipsum dolor". Predicted edits: "rem" -> "REM" (2..5) and
/// delete "um" (9..11).
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Only the fields `interpolate` is exercised against matter here; the prompt
    // inputs are empty placeholders.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unchanged buffer: edits are reported as predicted.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting "rem" turns the first edit into a pure insertion and shifts the deletion left.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing "R" (matching the prediction) leaves "EM" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        // Typing "E" leaves "M" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing "M" completes the first edit; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting the typed "M" brings the "M" insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Performing the predicted deletion ("um") leaves only the "M" insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // Deleting text the prediction did not anticipate invalidates it.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1553
1554#[gpui::test]
1555async fn test_clean_up_diff(cx: &mut TestAppContext) {
1556    init_test(cx);
1557
1558    assert_eq!(
1559        apply_edit_prediction(
1560            indoc! {"
1561                    fn main() {
1562                        let word_1 = \"lorem\";
1563                        let range = word.len()..word.len();
1564                    }
1565                "},
1566            indoc! {"
1567                    <|editable_region_start|>
1568                    fn main() {
1569                        let word_1 = \"lorem\";
1570                        let range = word_1.len()..word_1.len();
1571                    }
1572
1573                    <|editable_region_end|>
1574                "},
1575            cx,
1576        )
1577        .await,
1578        indoc! {"
1579                fn main() {
1580                    let word_1 = \"lorem\";
1581                    let range = word_1.len()..word_1.len();
1582                }
1583            "},
1584    );
1585
1586    assert_eq!(
1587        apply_edit_prediction(
1588            indoc! {"
1589                    fn main() {
1590                        let story = \"the quick\"
1591                    }
1592                "},
1593            indoc! {"
1594                    <|editable_region_start|>
1595                    fn main() {
1596                        let story = \"the quick brown fox jumps over the lazy dog\";
1597                    }
1598
1599                    <|editable_region_end|>
1600                "},
1601            cx,
1602        )
1603        .await,
1604        indoc! {"
1605                fn main() {
1606                    let story = \"the quick brown fox jumps over the lazy dog\";
1607                }
1608            "},
1609    );
1610}
1611
1612#[gpui::test]
1613async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1614    init_test(cx);
1615
1616    let buffer_content = "lorem\n";
1617    let completion_response = indoc! {"
1618            ```animals.js
1619            <|start_of_file|>
1620            <|editable_region_start|>
1621            lorem
1622            ipsum
1623            <|editable_region_end|>
1624            ```"};
1625
1626    assert_eq!(
1627        apply_edit_prediction(buffer_content, completion_response, cx).await,
1628        "lorem\nipsum"
1629    );
1630}
1631
1632#[gpui::test]
1633async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
1634    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
1635    // When the buffer ends without a trailing newline, but the model returns output
1636    // with a trailing newline, zeta2 should normalize both sides before diffing
1637    // so no spurious newline is inserted.
1638    let (ep_store, mut requests) = init_test_with_fake_client(cx);
1639    let fs = FakeFs::new(cx.executor());
1640
1641    // Single line buffer with no trailing newline
1642    fs.insert_tree(
1643        "/root",
1644        json!({
1645            "foo.txt": "hello"
1646        }),
1647    )
1648    .await;
1649    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1650
1651    let buffer = project
1652        .update(cx, |project, cx| {
1653            let path = project
1654                .find_project_path(path!("root/foo.txt"), cx)
1655                .unwrap();
1656            project.open_buffer(path, cx)
1657        })
1658        .await
1659        .unwrap();
1660
1661    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1662    let position = snapshot.anchor_before(language::Point::new(0, 5));
1663
1664    ep_store.update(cx, |ep_store, cx| {
1665        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
1666    });
1667
1668    let (_request, respond_tx) = requests.predict.next().await.unwrap();
1669
1670    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
1671    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
1672    let response = PredictEditsV3Response {
1673        request_id: Uuid::new_v4().to_string(),
1674        output: "hello world\n".to_string(),
1675    };
1676    respond_tx.send(response).unwrap();
1677
1678    cx.run_until_parked();
1679
1680    // The prediction should insert " world" without adding a newline
1681    ep_store.update(cx, |ep_store, cx| {
1682        let prediction = ep_store
1683            .prediction_at(&buffer, None, &project, cx)
1684            .expect("should have prediction");
1685        let edits: Vec<_> = prediction
1686            .edits
1687            .iter()
1688            .map(|(range, text)| {
1689                let snapshot = buffer.read(cx).snapshot();
1690                (range.to_offset(&snapshot), text.clone())
1691            })
1692            .collect();
1693        assert_eq!(edits, vec![(5..5, " world".into())]);
1694    });
1695}
1696
1697#[gpui::test]
1698async fn test_can_collect_data(cx: &mut TestAppContext) {
1699    init_test(cx);
1700
1701    let fs = project::FakeFs::new(cx.executor());
1702    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1703        .await;
1704
1705    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1706    let buffer = project
1707        .update(cx, |project, cx| {
1708            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1709        })
1710        .await
1711        .unwrap();
1712
1713    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1714    ep_store.update(cx, |ep_store, _cx| {
1715        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1716    });
1717
1718    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1719    assert_eq!(
1720        captured_request.lock().clone().unwrap().can_collect_data,
1721        true
1722    );
1723
1724    ep_store.update(cx, |ep_store, _cx| {
1725        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1726    });
1727
1728    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1729    assert_eq!(
1730        captured_request.lock().clone().unwrap().can_collect_data,
1731        false
1732    );
1733}
1734
1735#[gpui::test]
1736async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1737    init_test(cx);
1738
1739    let fs = project::FakeFs::new(cx.executor());
1740    let project = Project::test(fs.clone(), [], cx).await;
1741
1742    let buffer = cx.new(|_cx| {
1743        Buffer::remote(
1744            language::BufferId::new(1).unwrap(),
1745            ReplicaId::new(1),
1746            language::Capability::ReadWrite,
1747            "fn main() {\n    println!(\"Hello\");\n}",
1748        )
1749    });
1750
1751    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1752    ep_store.update(cx, |ep_store, _cx| {
1753        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1754    });
1755
1756    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1757    assert_eq!(
1758        captured_request.lock().clone().unwrap().can_collect_data,
1759        false
1760    );
1761}
1762
1763#[gpui::test]
1764async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1765    init_test(cx);
1766
1767    let fs = project::FakeFs::new(cx.executor());
1768    fs.insert_tree(
1769        path!("/project"),
1770        json!({
1771            "LICENSE": BSD_0_TXT,
1772            ".env": "SECRET_KEY=secret"
1773        }),
1774    )
1775    .await;
1776
1777    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1778    let buffer = project
1779        .update(cx, |project, cx| {
1780            project.open_local_buffer("/project/.env", cx)
1781        })
1782        .await
1783        .unwrap();
1784
1785    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1786    ep_store.update(cx, |ep_store, _cx| {
1787        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1788    });
1789
1790    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1791    assert_eq!(
1792        captured_request.lock().clone().unwrap().can_collect_data,
1793        false
1794    );
1795}
1796
1797#[gpui::test]
1798async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1799    init_test(cx);
1800
1801    let fs = project::FakeFs::new(cx.executor());
1802    let project = Project::test(fs.clone(), [], cx).await;
1803    let buffer = cx.new(|cx| Buffer::local("", cx));
1804
1805    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1806    ep_store.update(cx, |ep_store, _cx| {
1807        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1808    });
1809
1810    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1811    assert_eq!(
1812        captured_request.lock().clone().unwrap().can_collect_data,
1813        false
1814    );
1815}
1816
1817#[gpui::test]
1818async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1819    init_test(cx);
1820
1821    let fs = project::FakeFs::new(cx.executor());
1822    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1823        .await;
1824
1825    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1826    let buffer = project
1827        .update(cx, |project, cx| {
1828            project.open_local_buffer("/project/main.rs", cx)
1829        })
1830        .await
1831        .unwrap();
1832
1833    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1834    ep_store.update(cx, |ep_store, _cx| {
1835        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1836    });
1837
1838    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1839    assert_eq!(
1840        captured_request.lock().clone().unwrap().can_collect_data,
1841        false
1842    );
1843}
1844
1845#[gpui::test]
1846async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
1847    init_test(cx);
1848
1849    let fs = project::FakeFs::new(cx.executor());
1850    fs.insert_tree(
1851        path!("/open_source_worktree"),
1852        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
1853    )
1854    .await;
1855    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
1856        .await;
1857
1858    let project = Project::test(
1859        fs.clone(),
1860        [
1861            path!("/open_source_worktree").as_ref(),
1862            path!("/closed_source_worktree").as_ref(),
1863        ],
1864        cx,
1865    )
1866    .await;
1867    let buffer = project
1868        .update(cx, |project, cx| {
1869            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
1870        })
1871        .await
1872        .unwrap();
1873
1874    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1875    ep_store.update(cx, |ep_store, _cx| {
1876        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1877    });
1878
1879    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1880    assert_eq!(
1881        captured_request.lock().clone().unwrap().can_collect_data,
1882        true
1883    );
1884
1885    let closed_source_file = project
1886        .update(cx, |project, cx| {
1887            let worktree2 = project
1888                .worktree_for_root_name("closed_source_worktree", cx)
1889                .unwrap();
1890            worktree2.update(cx, |worktree2, cx| {
1891                worktree2.load_file(rel_path("main.rs"), cx)
1892            })
1893        })
1894        .await
1895        .unwrap()
1896        .file;
1897
1898    buffer.update(cx, |buffer, cx| {
1899        buffer.file_updated(closed_source_file, cx);
1900    });
1901
1902    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1903    assert_eq!(
1904        captured_request.lock().clone().unwrap().can_collect_data,
1905        false
1906    );
1907}
1908
1909#[gpui::test]
1910async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1911    init_test(cx);
1912
1913    let fs = project::FakeFs::new(cx.executor());
1914    fs.insert_tree(
1915        path!("/worktree1"),
1916        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1917    )
1918    .await;
1919    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1920        .await;
1921
1922    let project = Project::test(
1923        fs.clone(),
1924        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1925        cx,
1926    )
1927    .await;
1928    let buffer = project
1929        .update(cx, |project, cx| {
1930            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1931        })
1932        .await
1933        .unwrap();
1934    let private_buffer = project
1935        .update(cx, |project, cx| {
1936            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1937        })
1938        .await
1939        .unwrap();
1940
1941    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1942    ep_store.update(cx, |ep_store, _cx| {
1943        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1944    });
1945
1946    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1947    assert_eq!(
1948        captured_request.lock().clone().unwrap().can_collect_data,
1949        true
1950    );
1951
1952    // this has a side effect of registering the buffer to watch for edits
1953    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1954    assert_eq!(
1955        captured_request.lock().clone().unwrap().can_collect_data,
1956        false
1957    );
1958
1959    private_buffer.update(cx, |private_buffer, cx| {
1960        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1961    });
1962
1963    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1964    assert_eq!(
1965        captured_request.lock().clone().unwrap().can_collect_data,
1966        false
1967    );
1968
1969    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1970    // included
1971    buffer.update(cx, |buffer, cx| {
1972        buffer.edit(
1973            [(
1974                0..0,
1975                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1976            )],
1977            None,
1978            cx,
1979        );
1980    });
1981
1982    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1983    assert_eq!(
1984        captured_request.lock().clone().unwrap().can_collect_data,
1985        true
1986    );
1987}
1988
1989fn init_test(cx: &mut TestAppContext) {
1990    cx.update(|cx| {
1991        let settings_store = SettingsStore::test(cx);
1992        cx.set_global(settings_store);
1993    });
1994}
1995
1996async fn apply_edit_prediction(
1997    buffer_content: &str,
1998    completion_response: &str,
1999    cx: &mut TestAppContext,
2000) -> String {
2001    let fs = project::FakeFs::new(cx.executor());
2002    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2003    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
2004    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
2005    *response.lock() = completion_response.to_string();
2006    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
2007    buffer.update(cx, |buffer, cx| {
2008        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
2009    });
2010    buffer.read_with(cx, |buffer, _| buffer.text())
2011}
2012
2013async fn run_edit_prediction(
2014    buffer: &Entity<Buffer>,
2015    project: &Entity<Project>,
2016    ep_store: &Entity<EditPredictionStore>,
2017    cx: &mut TestAppContext,
2018) -> EditPrediction {
2019    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2020    ep_store.update(cx, |ep_store, cx| {
2021        ep_store.register_buffer(buffer, &project, cx)
2022    });
2023    cx.background_executor.run_until_parked();
2024    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2025        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2026    });
2027    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2028}
2029
2030async fn make_test_ep_store(
2031    project: &Entity<Project>,
2032    cx: &mut TestAppContext,
2033) -> (
2034    Entity<EditPredictionStore>,
2035    Arc<Mutex<Option<PredictEditsBody>>>,
2036    Arc<Mutex<String>>,
2037) {
2038    let default_response = indoc! {"
2039            ```main.rs
2040            <|start_of_file|>
2041            <|editable_region_start|>
2042            hello world
2043            <|editable_region_end|>
2044            ```"
2045    };
2046    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
2047    let completion_response: Arc<Mutex<String>> =
2048        Arc::new(Mutex::new(default_response.to_string()));
2049    let http_client = FakeHttpClient::create({
2050        let captured_request = captured_request.clone();
2051        let completion_response = completion_response.clone();
2052        let mut next_request_id = 0;
2053        move |req| {
2054            let captured_request = captured_request.clone();
2055            let completion_response = completion_response.clone();
2056            async move {
2057                match (req.method(), req.uri().path()) {
2058                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
2059                        .status(200)
2060                        .body(
2061                            serde_json::to_string(&CreateLlmTokenResponse {
2062                                token: LlmToken("the-llm-token".to_string()),
2063                            })
2064                            .unwrap()
2065                            .into(),
2066                        )
2067                        .unwrap()),
2068                    (&Method::POST, "/predict_edits/v2") => {
2069                        let mut request_body = String::new();
2070                        req.into_body().read_to_string(&mut request_body).await?;
2071                        *captured_request.lock() =
2072                            Some(serde_json::from_str(&request_body).unwrap());
2073                        next_request_id += 1;
2074                        Ok(http_client::Response::builder()
2075                            .status(200)
2076                            .body(
2077                                serde_json::to_string(&PredictEditsResponse {
2078                                    request_id: format!("request-{next_request_id}"),
2079                                    output_excerpt: completion_response.lock().clone(),
2080                                })
2081                                .unwrap()
2082                                .into(),
2083                            )
2084                            .unwrap())
2085                    }
2086                    _ => Ok(http_client::Response::builder()
2087                        .status(404)
2088                        .body("Not Found".into())
2089                        .unwrap()),
2090                }
2091            }
2092        }
2093    });
2094
2095    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2096    cx.update(|cx| {
2097        RefreshLlmTokenListener::register(client.clone(), cx);
2098    });
2099    let _server = FakeServer::for_client(42, &client, cx).await;
2100
2101    let ep_store = cx.new(|cx| {
2102        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
2103        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2104
2105        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
2106        for worktree in worktrees {
2107            let worktree_id = worktree.read(cx).id();
2108            ep_store
2109                .get_or_init_project(project, cx)
2110                .license_detection_watchers
2111                .entry(worktree_id)
2112                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
2113        }
2114
2115        ep_store
2116    });
2117
2118    (ep_store, captured_request, completion_response)
2119}
2120
2121fn to_completion_edits(
2122    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2123    buffer: &Entity<Buffer>,
2124    cx: &App,
2125) -> Vec<(Range<Anchor>, Arc<str>)> {
2126    let buffer = buffer.read(cx);
2127    iterator
2128        .into_iter()
2129        .map(|(range, text)| {
2130            (
2131                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2132                text,
2133            )
2134        })
2135        .collect()
2136}
2137
2138fn from_completion_edits(
2139    editor_edits: &[(Range<Anchor>, Arc<str>)],
2140    buffer: &Entity<Buffer>,
2141    cx: &App,
2142) -> Vec<(Range<usize>, Arc<str>)> {
2143    let buffer = buffer.read(cx);
2144    editor_edits
2145        .iter()
2146        .map(|(range, text)| {
2147            (
2148                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2149                text.clone(),
2150            )
2151        })
2152        .collect()
2153}
2154
2155#[gpui::test]
2156async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
2157    init_test(cx);
2158
2159    let fs = FakeFs::new(cx.executor());
2160    fs.insert_tree(
2161        "/project",
2162        serde_json::json!({
2163            "main.rs": "fn main() {\n    \n}\n"
2164        }),
2165    )
2166    .await;
2167
2168    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2169
2170    let http_client = FakeHttpClient::create(|_req| async move {
2171        Ok(gpui::http_client::Response::builder()
2172            .status(401)
2173            .body("Unauthorized".into())
2174            .unwrap())
2175    });
2176
2177    let client =
2178        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2179    cx.update(|cx| {
2180        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2181    });
2182
2183    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2184
2185    let buffer = project
2186        .update(cx, |project, cx| {
2187            let path = project
2188                .find_project_path(path!("/project/main.rs"), cx)
2189                .unwrap();
2190            project.open_buffer(path, cx)
2191        })
2192        .await
2193        .unwrap();
2194
2195    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2196    ep_store.update(cx, |ep_store, cx| {
2197        ep_store.register_buffer(&buffer, &project, cx)
2198    });
2199    cx.background_executor.run_until_parked();
2200
2201    let completion_task = ep_store.update(cx, |ep_store, cx| {
2202        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2203        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2204    });
2205
2206    let result = completion_task.await;
2207    assert!(
2208        result.is_err(),
2209        "Without authentication and without custom URL, prediction should fail"
2210    );
2211}
2212
2213#[gpui::test]
2214async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
2215    init_test(cx);
2216
2217    let fs = FakeFs::new(cx.executor());
2218    fs.insert_tree(
2219        "/project",
2220        serde_json::json!({
2221            "main.rs": "fn main() {\n    \n}\n"
2222        }),
2223    )
2224    .await;
2225
2226    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
2227
2228    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
2229    let predict_called_clone = predict_called.clone();
2230
2231    let http_client = FakeHttpClient::create({
2232        move |req| {
2233            let uri = req.uri().path().to_string();
2234            let predict_called = predict_called_clone.clone();
2235            async move {
2236                if uri.contains("predict") {
2237                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
2238                    Ok(gpui::http_client::Response::builder()
2239                        .body(
2240                            serde_json::to_string(&open_ai::Response {
2241                                id: "test-123".to_string(),
2242                                object: "chat.completion".to_string(),
2243                                created: 0,
2244                                model: "test".to_string(),
2245                                usage: open_ai::Usage {
2246                                    prompt_tokens: 0,
2247                                    completion_tokens: 0,
2248                                    total_tokens: 0,
2249                                },
2250                                choices: vec![open_ai::Choice {
2251                                    index: 0,
2252                                    message: open_ai::RequestMessage::Assistant {
2253                                        content: Some(open_ai::MessageContent::Plain(
2254                                            indoc! {"
2255                                                ```main.rs
2256                                                <|start_of_file|>
2257                                                <|editable_region_start|>
2258                                                fn main() {
2259                                                    println!(\"Hello, world!\");
2260                                                }
2261                                                <|editable_region_end|>
2262                                                ```
2263                                            "}
2264                                            .to_string(),
2265                                        )),
2266                                        tool_calls: vec![],
2267                                    },
2268                                    finish_reason: Some("stop".to_string()),
2269                                }],
2270                            })
2271                            .unwrap()
2272                            .into(),
2273                        )
2274                        .unwrap())
2275                } else {
2276                    Ok(gpui::http_client::Response::builder()
2277                        .status(401)
2278                        .body("Unauthorized".into())
2279                        .unwrap())
2280                }
2281            }
2282        }
2283    });
2284
2285    let client =
2286        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
2287    cx.update(|cx| {
2288        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
2289    });
2290
2291    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
2292
2293    let buffer = project
2294        .update(cx, |project, cx| {
2295            let path = project
2296                .find_project_path(path!("/project/main.rs"), cx)
2297                .unwrap();
2298            project.open_buffer(path, cx)
2299        })
2300        .await
2301        .unwrap();
2302
2303    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
2304    ep_store.update(cx, |ep_store, cx| {
2305        ep_store.register_buffer(&buffer, &project, cx)
2306    });
2307    cx.background_executor.run_until_parked();
2308
2309    let completion_task = ep_store.update(cx, |ep_store, cx| {
2310        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
2311        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
2312        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
2313    });
2314
2315    let _ = completion_task.await;
2316
2317    assert!(
2318        predict_called.load(std::sync::atomic::Ordering::SeqCst),
2319        "With custom URL, predict endpoint should be called even without authentication"
2320    );
2321}
2322
2323#[gpui::test]
2324fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
2325    let buffer = cx.new(|cx| {
2326        Buffer::local(
2327            indoc! {"
2328                zero
2329                one
2330                two
2331                three
2332                four
2333                five
2334                six
2335                seven
2336                eight
2337                nine
2338                ten
2339                eleven
2340                twelve
2341                thirteen
2342                fourteen
2343                fifteen
2344                sixteen
2345                seventeen
2346                eighteen
2347                nineteen
2348                twenty
2349                twenty-one
2350                twenty-two
2351                twenty-three
2352                twenty-four
2353            "},
2354            cx,
2355        )
2356    });
2357
2358    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2359
2360    buffer.update(cx, |buffer, cx| {
2361        let point = Point::new(12, 0);
2362        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
2363        let point = Point::new(8, 0);
2364        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
2365    });
2366
2367    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
2368
2369    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
2370
2371    assert_eq!(
2372        diff,
2373        indoc! {"
2374            @@ -6,10 +6,12 @@
2375             five
2376             six
2377             seven
2378            +FIRST INSERTION
2379             eight
2380             nine
2381             ten
2382             eleven
2383            +SECOND INSERTION
2384             twelve
2385             thirteen
2386             fourteen
2387            "}
2388    );
2389}
2390
/// Initializes test logging via `zlog::init_test` once, when the test binary loads
/// (`#[ctor::ctor]` runs this before `main`).
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}