//! edit_prediction_tests.rs — integration tests for the edit prediction store.

   1use super::*;
   2use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
   3use client::{UserStore, test::FakeServer};
   4use clock::{FakeSystemClock, ReplicaId};
   5use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
   6use cloud_llm_client::{
   7    EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
   8    RejectEditPredictionsBody,
   9    predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
  10};
  11use futures::{
  12    AsyncReadExt, StreamExt,
  13    channel::{mpsc, oneshot},
  14};
  15use gpui::App;
  16use gpui::{
  17    Entity, TestAppContext,
  18    http_client::{FakeHttpClient, Response},
  19};
  20use indoc::indoc;
  21use language::{Buffer, Point};
  22use lsp::LanguageServerId;
  23use parking_lot::Mutex;
  24use pretty_assertions::{assert_eq, assert_matches};
  25use project::{FakeFs, Project};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, sync::Arc, time::Duration};
  29use util::{path, rel_path::rel_path};
  30use uuid::Uuid;
  31use zeta_prompt::ZetaPromptInput;
  32
  33use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
  34
// End-to-end check of `prediction_at` state transitions:
// - a prediction for the file the cursor is in is reported as `Local`;
// - after rejecting it, a prediction produced for a *different* file
//   (prompted here by publishing a diagnostic for that file) is reported as
//   `Jump` from the first buffer, and as `Local` once that buffer is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    // Open 1.txt and make it the active path so predictions originate there.
    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer1, &project, cx);
    });

    // Prediction for current file

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
    });
    let (request, respond_tx) = requests.predict.next().await.unwrap();

    respond_tx
        .send(model_response(
            &request,
            indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    cx.run_until_parked();

    // The edit targets the buffer the cursor is in, so it surfaces as `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
    });

    // Prediction for diagnostic in another file

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // Publishing the diagnostic for 2.txt triggers a new prediction request
    // (the next `requests.predict` item below) without any cursor activity.
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#},
        ))
        .unwrap();
    cx.run_until_parked();

    // Viewed from buffer 1, the prediction points at another file: `Jump`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer1, None, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // Viewed from the target buffer itself, the same prediction is `Local`.
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer2, None, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
 170
// Basic request/response round trip: the model's unified-diff response is
// converted into a prediction whose single edit lands at the cursor position.
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(
            &request,
            indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    // The diff rewrites "How" to "How are you?"; against the unchanged buffer
    // this reduces to a single insertion of " are you?" at (1, 3).
    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 237
// Verifies that an edit made to a registered buffer is recorded and included
// (as a unified diff) in the prompt of the next prediction request.
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // Edit after registration so the change is captured in edit history.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // The outgoing prompt must contain the diff of the edit made above.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ -1,3 +1,3 @@
         Hello!
        -
        +How
         Bye
    "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(
            &request,
            indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
        "#},
        ))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
 310
// Exercises the pause-splitting edit-history getter: a pause longer than the
// grouping threshold inside a single coalesced event makes
// `edit_history_for_project_with_pause_split_last_event` return two diffs,
// while the plain `edit_history_for_project` getter still returns one.
#[gpui::test]
async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First burst: insert "How"
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    // Simulate a pause longer than the grouping threshold (e.g. 500ms).
    cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2);
    cx.run_until_parked();

    // Second burst: append " are you?" immediately after "How" on the same line.
    //
    // Keeping both bursts on the same line ensures the existing line-span coalescing logic
    // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(10..10, " are you?")], None, cx);
    });

    // A second edit shortly after the first post-pause edit ensures the last edit timestamp is
    // advanced after the pause boundary is recorded, making pause-splitting deterministic.
    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(19..19, "!")], None, cx);
    });

    // Without time-based splitting, there is one event.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(events.len(), 1);
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How are you?!
             Bye
        "}
    );

    // With time-based splitting, there are two distinct events.
    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });
    assert_eq!(events.len(), 2);
    // First event: everything up to the pause boundary.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}
    );

    // Second event: the post-pause edits, diffed against the first event's result.
    let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
    assert_eq!(
        diff.as_str(),
        indoc! {"
            @@ -1,3 +1,3 @@
             Hello!
            -How
            +How are you?!
             Bye
        "}
    );
}
 405
// Exercises line-span coalescing of edit-history events: edits within 8 lines
// of an existing event's range (above or below it) merge into that event,
// while edits farther away start a new event.
#[gpui::test]
async fn test_event_grouping_line_span_coalescing(cx: &mut TestAppContext) {
    let (ep_store, _requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Create a file with 30 lines to test line-based coalescing
    let content = (1..=30)
        .map(|i| format!("Line {}\n", i))
        .collect::<String>();
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": content
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit: multi-line edit spanning rows 10-12 (replacing lines 11-13)
    buffer.update(cx, |buffer, cx| {
        let start = Point::new(10, 0).to_offset(buffer);
        let end = Point::new(13, 0).to_offset(buffer);
        buffer.edit(vec![(start..end, "Middle A\nMiddle B\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -8,9 +8,8 @@
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After first edit"
    );

    // Second edit: insert ABOVE the first edit's range (row 5, within 8 lines of row 10)
    // This tests that coalescing considers the START of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(5, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Above\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,14 +3,14 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
             Line 15
             Line 16
        "},
        "After inserting above (should coalesce)"
    );

    // Third edit: insert BELOW the first edit's range (row 14 in current buffer, within 8 lines of row 12)
    // This tests that coalescing considers the END of the existing range
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(14, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17
        "},
        "After inserting below (should coalesce)"
    );

    // Fourth edit: insert FAR BELOW (row 25, beyond 8 lines from the current range end ~row 15)
    // This should NOT coalesce - creates a new event
    buffer.update(cx, |buffer, cx| {
        let offset = Point::new(25, 0).to_offset(buffer);
        buffer.edit(vec![(offset..offset, "Far below\n")], None, cx);
    });

    let events = ep_store.update(cx, |ep_store, cx| {
        ep_store.edit_history_for_project(&project, cx)
    });
    assert_eq!(
        render_events(&events),
        indoc! {"
            @@ -3,15 +3,16 @@
             Line 3
             Line 4
             Line 5
            +Above
             Line 6
             Line 7
             Line 8
             Line 9
             Line 10
            -Line 11
            -Line 12
            -Line 13
            +Middle A
            +Middle B
             Line 14
            +Below
             Line 15
             Line 16
             Line 17

            ---
            @@ -23,6 +23,7 @@
             Line 22
             Line 23
             Line 24
            +Far below
             Line 25
             Line 26
             Line 27
        "},
        "After inserting far below (should NOT coalesce)"
    );
}
 584
 585fn render_events(events: &[StoredEvent]) -> String {
 586    events
 587        .iter()
 588        .map(|e| {
 589            let zeta_prompt::Event::BufferChange { diff, .. } = e.event.as_ref();
 590            diff.as_str()
 591        })
 592        .collect::<Vec<_>>()
 593        .join("\n---\n")
 594}
 595
// An empty model response yields no prediction, and the response is
// auto-reported as rejected with `EditPredictionRejectReason::Empty`
// (`was_shown: false` — it was never surfaced to the user).
#[gpui::test]
async fn test_empty_prediction(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Respond with an empty diff: no edits can be derived from it.
    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let response = model_response(&request, "");
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::Empty,
            was_shown: false
        }]
    );
}
 650
// If the user types the predicted text before the response arrives,
// interpolating the diff against the new buffer contents yields no remaining
// edits; the prediction is dropped and reported as `InterpolatedEmpty`.
#[gpui::test]
async fn test_interpolated_empty(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();

    // Apply the predicted change manually while the request is still in flight.
    buffer.update(cx, |buffer, cx| {
        buffer.set_text("Hello!\nHow are you?\nBye", cx);
    });

    let response = model_response(&request, SIMPLE_DIFF);
    let id = response.request_id.clone();
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        assert!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .is_none()
        );
    });

    // prediction is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: id,
            reason: EditPredictionRejectReason::InterpolatedEmpty,
            was_shown: false
        }]
    );
}
 710
// Canned model response shared by several tests below: a unified diff that
// rewrites "How" to "How are you?" in root/foo.md.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
 720
// A newer prediction replaces the current one; the replaced prediction is
// reported as rejected with `EditPredictionRejectReason::Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // second replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
 802
// When a new response is judged worse than the current prediction, the current
// one is kept and the new one is reported with `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(&request, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First response becomes the current prediction.
    ep_store.update(cx, |ep_store, cx| {
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(
        &request,
        indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "},
    );
    let second_id = second_response.request_id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // first is preferred over second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
 896
// When a later request completes first, earlier still-pending requests are
// canceled: their late responses are ignored and reported with `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(&request, SIMPLE_DIFF);
    let second_id = second_response.request_id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is second
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the first (stale) request respond; it should have no effect.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
 987
/// When a third refresh starts while two predictions are already in flight,
/// the store cancels the second (newest pending) request. Its eventual
/// response must be ignored, and the server must be told it was canceled,
/// while the first prediction is later reported as replaced by the third.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md":  "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor anchored inside the "How" line.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request1, respond_first) = requests.predict.next().await.unwrap();

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (request2, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // start a third request while the first two are still unanswered
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so the 2nd (index 1) is marked cancelled
        assert_eq!(
            ep_store
                .get_or_init_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (request3, respond_third) = requests.predict.next().await.unwrap();

    // Answer the first request; it was never cancelled, so it should land.
    let first_response = model_response(&request1, SIMPLE_DIFF);
    let first_id = first_response.request_id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Answer the cancelled second request; its response must be discarded.
    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
    let cancelled_id = cancelled_response.request_id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(&request3, SIMPLE_DIFF);
    let third_response_id = third_response.request_id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    ep_store.update(cx, |ep_store, cx| {
        // third completes and replaces first
        assert_eq!(
            ep_store
                .prediction_at(&buffer, None, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // Both outcomes are reported in a single batched reject request:
    // second as Canceled, first as Replaced.
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
1123
/// Exercises how rejection reports are flushed to the server:
/// - multiple rejections within the debounce window are batched into one request;
/// - reaching the batch-size limit sends the first half immediately, without
///   waiting for the debounce;
/// - a failed reject request keeps its items queued and retries them together
///   with any rejections added later.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
        ep_store.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched into a single request, in submission order
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    ep_store.update(cx, |ep_store, cx| {
        for i in 0..70 {
            ep_store.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
                cx,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure by dropping the responder, which errors the fake
    // server's oneshot receiver.
    drop(_respond_tx);

    // Add another rejection
    ep_store.update(cx, |ep_store, cx| {
        ep_store.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
            cx,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
1240
1241// Skipped until we start including diagnostics in prompt
1242// #[gpui::test]
1243// async fn test_request_diagnostics(cx: &mut TestAppContext) {
1244//     let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
1245//     let fs = FakeFs::new(cx.executor());
1246//     fs.insert_tree(
1247//         "/root",
1248//         json!({
1249//             "foo.md": "Hello!\nBye"
1250//         }),
1251//     )
1252//     .await;
1253//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1254
1255//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1256//     let diagnostic = lsp::Diagnostic {
1257//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1258//         severity: Some(lsp::DiagnosticSeverity::ERROR),
1259//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1260//         ..Default::default()
1261//     };
1262
1263//     project.update(cx, |project, cx| {
1264//         project.lsp_store().update(cx, |lsp_store, cx| {
1265//             // Create some diagnostics
1266//             lsp_store
1267//                 .update_diagnostics(
1268//                     LanguageServerId(0),
1269//                     lsp::PublishDiagnosticsParams {
1270//                         uri: path_to_buffer_uri.clone(),
1271//                         diagnostics: vec![diagnostic],
1272//                         version: None,
1273//                     },
1274//                     None,
1275//                     language::DiagnosticSourceKind::Pushed,
1276//                     &[],
1277//                     cx,
1278//                 )
1279//                 .unwrap();
1280//         });
1281//     });
1282
1283//     let buffer = project
1284//         .update(cx, |project, cx| {
1285//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1286//             project.open_buffer(path, cx)
1287//         })
1288//         .await
1289//         .unwrap();
1290
1291//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1292//     let position = snapshot.anchor_before(language::Point::new(0, 0));
1293
1294//     let _prediction_task = ep_store.update(cx, |ep_store, cx| {
1295//         ep_store.request_prediction(&project, &buffer, position, cx)
1296//     });
1297
1298//     let (request, _respond_tx) = req_rx.next().await.unwrap();
1299
1300//     assert_eq!(request.diagnostic_groups.len(), 1);
1301//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1302//         .unwrap();
1303//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1304//     assert_eq!(
1305//         value,
1306//         json!({
1307//             "entries": [{
1308//                 "range": {
1309//                     "start": 8,
1310//                     "end": 10
1311//                 },
1312//                 "diagnostic": {
1313//                     "source": null,
1314//                     "code": null,
1315//                     "code_description": null,
1316//                     "severity": 1,
1317//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1318//                     "markdown": null,
1319//                     "group_id": 0,
1320//                     "is_primary": true,
1321//                     "is_disk_based": false,
1322//                     "is_unnecessary": false,
1323//                     "source_kind": "Pushed",
1324//                     "data": null,
1325//                     "underline": true
1326//                 }
1327//             }],
1328//             "primary_ix": 0
1329//         })
1330//     );
1331// }
1332
1333// Generate a model response that would apply the given diff to the active file.
1334fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
1335    let excerpt =
1336        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
1337    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
1338
1339    PredictEditsV3Response {
1340        request_id: Uuid::new_v4().to_string(),
1341        output: new_excerpt,
1342    }
1343}
1344
1345fn prompt_from_request(request: &PredictEditsV3Request) -> String {
1346    zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
1347}
1348
/// Receiving ends of the channels through which the fake HTTP client hands
/// intercepted server requests to the test. Each item pairs the decoded
/// request body with a oneshot sender the test uses to supply the response.
struct RequestChannels {
    // One item per POST to /predict_edits/v3.
    predict: mpsc::UnboundedReceiver<(
        PredictEditsV3Request,
        oneshot::Sender<PredictEditsV3Response>,
    )>,
    // One item per POST to /predict_edits/reject.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
1356
/// Sets up an `EditPredictionStore` backed by a fake HTTP client that routes
/// prediction and rejection requests onto channels, letting tests observe
/// every outgoing request and choose when (and with what) to respond.
fn init_test_with_fake_client(
    cx: &mut TestAppContext,
) -> (Entity<EditPredictionStore>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token fetches are answered immediately with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction requests block until the test sends a
                        // response through the paired oneshot channel.
                        "/predict_edits/v3" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // If the test drops the sender, `?` turns the
                            // cancellation into a request error.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection reports likewise wait for the test to ack.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let ep_store = EditPredictionStore::global(&client, &user_store, cx);

        (
            ep_store,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
1425
// Permissively-licensed (0BSD) license text. Inserting it as a `LICENSE` file
// makes a test worktree count as open source in the data-collection tests below.
const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
1427
/// Verifies that a prediction's edits are re-interpolated as the user keeps
/// editing the buffer: edits shift with the user's changes, typing the
/// predicted text consumes the corresponding edit, undo restores it, and a
/// conflicting change invalidates the prediction entirely (interpolate → None).
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
    let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
    // The prediction proposes replacing offsets 2..5 with "REM" and deleting 9..11.
    let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
        to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
    });

    let edit_preview = cx
        .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
        .await;

    // Hand-built prediction; the prompt inputs are irrelevant here, so they
    // are left empty/default.
    let prediction = EditPrediction {
        edits,
        cursor_position: None,
        edit_preview,
        buffer: buffer.clone(),
        snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
        id: EditPredictionId("the-id".into()),
        inputs: ZetaPromptInput {
            events: Default::default(),
            related_files: Default::default(),
            cursor_path: Path::new("").into(),
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            cursor_offset_in_excerpt: 0,
            excerpt_start_row: None,
        },
        buffer_snapshotted_at: Instant::now(),
        response_received_at: Instant::now(),
    };

    cx.update(|cx| {
        // Unmodified buffer: interpolation returns the original edits.
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Deleting the replaced range turns the first edit into a pure
        // insertion, and the second edit shifts left by the deleted length.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".into()), (6..8, "".into())]
        );

        // Undo restores the original interpolation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".into()), (9..11, "".into())]
        );

        // Typing a prefix of the predicted text shrinks the remaining insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".into()), (7..9, "".into())]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into()), (8..10, "".into())]
        );

        // Typing the final predicted character consumes the first edit entirely.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".into())]
        );

        // Deleting part of what was typed revives the insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // Performing the predicted deletion manually consumes that edit.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".into())]
        );

        // A conflicting edit invalidates the prediction.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
    })
}
1543
/// Checks that applying a model response cleans up artifacts in its output:
/// in both cases the model emits an extra blank line before
/// `<|editable_region_end|>` that must not end up in the buffer.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
    init_test(cx);

    // Model renames `word` to `word_1` on one line; the spurious trailing
    // blank line in its output is dropped.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word.len()..word.len();
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let word_1 = \"lorem\";
                        let range = word_1.len()..word_1.len();
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let word_1 = \"lorem\";
                    let range = word_1.len()..word_1.len();
                }
            "},
    );

    // Model rewrites an entire line (completing the string and adding the
    // semicolon); again the extra blank line is not applied.
    assert_eq!(
        apply_edit_prediction(
            indoc! {"
                    fn main() {
                        let story = \"the quick\"
                    }
                "},
            indoc! {"
                    <|editable_region_start|>
                    fn main() {
                        let story = \"the quick brown fox jumps over the lazy dog\";
                    }

                    <|editable_region_end|>
                "},
            cx,
        )
        .await,
        indoc! {"
                fn main() {
                    let story = \"the quick brown fox jumps over the lazy dog\";
                }
            "},
    );
}
1601
1602#[gpui::test]
1603async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
1604    init_test(cx);
1605
1606    let buffer_content = "lorem\n";
1607    let completion_response = indoc! {"
1608            ```animals.js
1609            <|start_of_file|>
1610            <|editable_region_start|>
1611            lorem
1612            ipsum
1613            <|editable_region_end|>
1614            ```"};
1615
1616    assert_eq!(
1617        apply_edit_prediction(buffer_content, completion_response, cx).await,
1618        "lorem\nipsum"
1619    );
1620}
1621
#[gpui::test]
async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
    // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
    // When the buffer ends without a trailing newline, but the model returns output
    // with a trailing newline, zeta2 should normalize both sides before diffing
    // so no spurious newline is inserted.
    let (ep_store, mut requests) = init_test_with_fake_client(cx);
    let fs = FakeFs::new(cx.executor());

    // Single line buffer with no trailing newline
    fs.insert_tree(
        "/root",
        json!({
            "foo.txt": "hello"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("root/foo.txt"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor at the end of "hello".
    let position = snapshot.anchor_before(language::Point::new(0, 5));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_request, respond_tx) = requests.predict.next().await.unwrap();

    // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
    // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
    let response = PredictEditsV3Response {
        request_id: Uuid::new_v4().to_string(),
        output: "hello world\n".to_string(),
    };
    respond_tx.send(response).unwrap();

    cx.run_until_parked();

    // The prediction should insert " world" without adding a newline
    ep_store.update(cx, |ep_store, cx| {
        let prediction = ep_store
            .prediction_at(&buffer, None, &project, cx)
            .expect("should have prediction");
        // Resolve each edit's anchor range to concrete offsets for comparison.
        let edits: Vec<_> = prediction
            .edits
            .iter()
            .map(|(range, text)| {
                let snapshot = buffer.read(cx).snapshot();
                (range.to_offset(&snapshot), text.clone())
            })
            .collect();
        assert_eq!(edits, vec![(5..5, " world".into())]);
    });
}
1686
1687#[gpui::test]
1688async fn test_can_collect_data(cx: &mut TestAppContext) {
1689    init_test(cx);
1690
1691    let fs = project::FakeFs::new(cx.executor());
1692    fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
1693        .await;
1694
1695    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1696    let buffer = project
1697        .update(cx, |project, cx| {
1698            project.open_local_buffer(path!("/project/src/main.rs"), cx)
1699        })
1700        .await
1701        .unwrap();
1702
1703    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1704    ep_store.update(cx, |ep_store, _cx| {
1705        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1706    });
1707
1708    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1709    assert_eq!(
1710        captured_request.lock().clone().unwrap().can_collect_data,
1711        true
1712    );
1713
1714    ep_store.update(cx, |ep_store, _cx| {
1715        ep_store.data_collection_choice = DataCollectionChoice::Disabled
1716    });
1717
1718    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1719    assert_eq!(
1720        captured_request.lock().clone().unwrap().can_collect_data,
1721        false
1722    );
1723}
1724
1725#[gpui::test]
1726async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
1727    init_test(cx);
1728
1729    let fs = project::FakeFs::new(cx.executor());
1730    let project = Project::test(fs.clone(), [], cx).await;
1731
1732    let buffer = cx.new(|_cx| {
1733        Buffer::remote(
1734            language::BufferId::new(1).unwrap(),
1735            ReplicaId::new(1),
1736            language::Capability::ReadWrite,
1737            "fn main() {\n    println!(\"Hello\");\n}",
1738        )
1739    });
1740
1741    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1742    ep_store.update(cx, |ep_store, _cx| {
1743        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1744    });
1745
1746    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1747    assert_eq!(
1748        captured_request.lock().clone().unwrap().can_collect_data,
1749        false
1750    );
1751}
1752
1753#[gpui::test]
1754async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
1755    init_test(cx);
1756
1757    let fs = project::FakeFs::new(cx.executor());
1758    fs.insert_tree(
1759        path!("/project"),
1760        json!({
1761            "LICENSE": BSD_0_TXT,
1762            ".env": "SECRET_KEY=secret"
1763        }),
1764    )
1765    .await;
1766
1767    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1768    let buffer = project
1769        .update(cx, |project, cx| {
1770            project.open_local_buffer("/project/.env", cx)
1771        })
1772        .await
1773        .unwrap();
1774
1775    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1776    ep_store.update(cx, |ep_store, _cx| {
1777        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1778    });
1779
1780    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1781    assert_eq!(
1782        captured_request.lock().clone().unwrap().can_collect_data,
1783        false
1784    );
1785}
1786
1787#[gpui::test]
1788async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
1789    init_test(cx);
1790
1791    let fs = project::FakeFs::new(cx.executor());
1792    let project = Project::test(fs.clone(), [], cx).await;
1793    let buffer = cx.new(|cx| Buffer::local("", cx));
1794
1795    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1796    ep_store.update(cx, |ep_store, _cx| {
1797        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1798    });
1799
1800    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1801    assert_eq!(
1802        captured_request.lock().clone().unwrap().can_collect_data,
1803        false
1804    );
1805}
1806
1807#[gpui::test]
1808async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
1809    init_test(cx);
1810
1811    let fs = project::FakeFs::new(cx.executor());
1812    fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
1813        .await;
1814
1815    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1816    let buffer = project
1817        .update(cx, |project, cx| {
1818            project.open_local_buffer("/project/main.rs", cx)
1819        })
1820        .await
1821        .unwrap();
1822
1823    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1824    ep_store.update(cx, |ep_store, _cx| {
1825        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1826    });
1827
1828    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1829    assert_eq!(
1830        captured_request.lock().clone().unwrap().can_collect_data,
1831        false
1832    );
1833}
1834
#[gpui::test]
async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
    init_test(cx);

    // Two worktrees: one with a permissive BSD license file (open source) and
    // one with no license (closed source).
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        path!("/open_source_worktree"),
        json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
    )
    .await;
    fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
        .await;

    let project = Project::test(
        fs.clone(),
        [
            path!("/open_source_worktree").as_ref(),
            path!("/closed_source_worktree").as_ref(),
        ],
        cx,
    )
    .await;
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
        })
        .await
        .unwrap();

    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
    ep_store.update(cx, |ep_store, _cx| {
        ep_store.data_collection_choice = DataCollectionChoice::Enabled
    });

    // While the buffer lives in the open-source worktree, collection is allowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        true
    );

    // Load the file handle for main.rs in the closed-source worktree so we can
    // simulate the buffer being moved there.
    let closed_source_file = project
        .update(cx, |project, cx| {
            let worktree2 = project
                .worktree_for_root_name("closed_source_worktree", cx)
                .unwrap();
            worktree2.update(cx, |worktree2, cx| {
                worktree2.load_file(rel_path("main.rs"), cx)
            })
        })
        .await
        .unwrap()
        .file;

    // Reassign the buffer's file: the buffer now belongs to the closed-source
    // worktree without reopening it.
    buffer.update(cx, |buffer, cx| {
        buffer.file_updated(closed_source_file, cx);
    });

    // After the move, the next request must report collection as disallowed.
    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
    assert_eq!(
        captured_request.lock().clone().unwrap().can_collect_data,
        false
    );
}
1898
1899#[gpui::test]
1900async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
1901    init_test(cx);
1902
1903    let fs = project::FakeFs::new(cx.executor());
1904    fs.insert_tree(
1905        path!("/worktree1"),
1906        json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
1907    )
1908    .await;
1909    fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
1910        .await;
1911
1912    let project = Project::test(
1913        fs.clone(),
1914        [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
1915        cx,
1916    )
1917    .await;
1918    let buffer = project
1919        .update(cx, |project, cx| {
1920            project.open_local_buffer(path!("/worktree1/main.rs"), cx)
1921        })
1922        .await
1923        .unwrap();
1924    let private_buffer = project
1925        .update(cx, |project, cx| {
1926            project.open_local_buffer(path!("/worktree2/file.rs"), cx)
1927        })
1928        .await
1929        .unwrap();
1930
1931    let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
1932    ep_store.update(cx, |ep_store, _cx| {
1933        ep_store.data_collection_choice = DataCollectionChoice::Enabled
1934    });
1935
1936    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1937    assert_eq!(
1938        captured_request.lock().clone().unwrap().can_collect_data,
1939        true
1940    );
1941
1942    // this has a side effect of registering the buffer to watch for edits
1943    run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
1944    assert_eq!(
1945        captured_request.lock().clone().unwrap().can_collect_data,
1946        false
1947    );
1948
1949    private_buffer.update(cx, |private_buffer, cx| {
1950        private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
1951    });
1952
1953    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1954    assert_eq!(
1955        captured_request.lock().clone().unwrap().can_collect_data,
1956        false
1957    );
1958
1959    // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
1960    // included
1961    buffer.update(cx, |buffer, cx| {
1962        buffer.edit(
1963            [(
1964                0..0,
1965                " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
1966            )],
1967            None,
1968            cx,
1969        );
1970    });
1971
1972    run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1973    assert_eq!(
1974        captured_request.lock().clone().unwrap().can_collect_data,
1975        true
1976    );
1977}
1978
1979fn init_test(cx: &mut TestAppContext) {
1980    cx.update(|cx| {
1981        let settings_store = SettingsStore::test(cx);
1982        cx.set_global(settings_store);
1983    });
1984}
1985
1986async fn apply_edit_prediction(
1987    buffer_content: &str,
1988    completion_response: &str,
1989    cx: &mut TestAppContext,
1990) -> String {
1991    let fs = project::FakeFs::new(cx.executor());
1992    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1993    let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
1994    let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
1995    *response.lock() = completion_response.to_string();
1996    let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
1997    buffer.update(cx, |buffer, cx| {
1998        buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
1999    });
2000    buffer.read_with(cx, |buffer, _| buffer.text())
2001}
2002
2003async fn run_edit_prediction(
2004    buffer: &Entity<Buffer>,
2005    project: &Entity<Project>,
2006    ep_store: &Entity<EditPredictionStore>,
2007    cx: &mut TestAppContext,
2008) -> EditPrediction {
2009    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
2010    ep_store.update(cx, |ep_store, cx| {
2011        ep_store.register_buffer(buffer, &project, cx)
2012    });
2013    cx.background_executor.run_until_parked();
2014    let prediction_task = ep_store.update(cx, |ep_store, cx| {
2015        ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
2016    });
2017    prediction_task.await.unwrap().unwrap().prediction.unwrap()
2018}
2019
/// Builds an `EditPredictionStore` backed by a fake HTTP client.
///
/// Returns the store plus two shared handles for the test to use:
/// - a slot that captures the most recent `PredictEditsBody` sent to
///   `/predict_edits/v2`, and
/// - the response text the fake server will return (pre-filled with a
///   default Zeta1-format excerpt; tests may overwrite it).
///
/// Also installs a `LicenseDetectionWatcher` for each of the project's
/// worktrees so data-collection eligibility checks work in tests.
async fn make_test_ep_store(
    project: &Entity<Project>,
    cx: &mut TestAppContext,
) -> (
    Entity<EditPredictionStore>,
    Arc<Mutex<Option<PredictEditsBody>>>,
    Arc<Mutex<String>>,
) {
    let default_response = indoc! {"
            ```main.rs
            <|start_of_file|>
            <|editable_region_start|>
            hello world
            <|editable_region_end|>
            ```"
    };
    let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
    let completion_response: Arc<Mutex<String>> =
        Arc::new(Mutex::new(default_response.to_string()));
    let http_client = FakeHttpClient::create({
        let captured_request = captured_request.clone();
        let completion_response = completion_response.clone();
        // Counter captured by the handler closure; incremented per prediction
        // request so each response carries a unique request id.
        let mut next_request_id = 0;
        move |req| {
            let captured_request = captured_request.clone();
            let completion_response = completion_response.clone();
            async move {
                match (req.method(), req.uri().path()) {
                    // Token refresh endpoint: always succeed with a fixed token.
                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                        .status(200)
                        .body(
                            serde_json::to_string(&CreateLlmTokenResponse {
                                token: LlmToken("the-llm-token".to_string()),
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap()),
                    // Prediction endpoint: record the request body for the test
                    // to inspect, then return the configured response text.
                    (&Method::POST, "/predict_edits/v2") => {
                        let mut request_body = String::new();
                        req.into_body().read_to_string(&mut request_body).await?;
                        *captured_request.lock() =
                            Some(serde_json::from_str(&request_body).unwrap());
                        next_request_id += 1;
                        Ok(http_client::Response::builder()
                            .status(200)
                            .body(
                                serde_json::to_string(&PredictEditsResponse {
                                    request_id: format!("request-{next_request_id}"),
                                    output_excerpt: completion_response.lock().clone(),
                                })
                                .unwrap()
                                .into(),
                            )
                            .unwrap())
                    }
                    _ => Ok(http_client::Response::builder()
                        .status(404)
                        .body("Not Found".into())
                        .unwrap()),
                }
            }
        }
    });

    let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        RefreshLlmTokenListener::register(client.clone(), cx);
    });
    let _server = FakeServer::for_client(42, &client, cx).await;

    let ep_store = cx.new(|cx| {
        let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);

        // Seed a license watcher per worktree so open/closed-source detection
        // (used by can_collect_data) is active in tests.
        let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
        for worktree in worktrees {
            let worktree_id = worktree.read(cx).id();
            ep_store
                .get_or_init_project(project, cx)
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
        }

        ep_store
    });

    (ep_store, captured_request, completion_response)
}
2110
2111fn to_completion_edits(
2112    iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
2113    buffer: &Entity<Buffer>,
2114    cx: &App,
2115) -> Vec<(Range<Anchor>, Arc<str>)> {
2116    let buffer = buffer.read(cx);
2117    iterator
2118        .into_iter()
2119        .map(|(range, text)| {
2120            (
2121                buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
2122                text,
2123            )
2124        })
2125        .collect()
2126}
2127
2128fn from_completion_edits(
2129    editor_edits: &[(Range<Anchor>, Arc<str>)],
2130    buffer: &Entity<Buffer>,
2131    cx: &App,
2132) -> Vec<(Range<usize>, Arc<str>)> {
2133    let buffer = buffer.read(cx);
2134    editor_edits
2135        .iter()
2136        .map(|(range, text)| {
2137            (
2138                range.start.to_offset(buffer)..range.end.to_offset(buffer),
2139                text.clone(),
2140            )
2141        })
2142        .collect()
2143}
2144
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Every HTTP request fails with 401: the client can never authenticate.
    let http_client = FakeHttpClient::create(|_req| async move {
        Ok(gpui::http_client::Response::builder()
            .status(401)
            .body("Unauthorized".into())
            .unwrap())
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    // No custom predict-edits URL is set here, so the store must go through
    // the (failing) authenticated path.
    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    let result = completion_task.await;
    assert!(
        result.is_err(),
        "Without authentication and without custom URL, prediction should fail"
    );
}
2202
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        serde_json::json!({
            "main.rs": "fn main() {\n    \n}\n"
        }),
    )
    .await;

    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;

    // Flag flipped by the fake server when the predict endpoint is hit; this
    // is what the test asserts on at the end.
    let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let predict_called_clone = predict_called.clone();

    let http_client = FakeHttpClient::create({
        move |req| {
            let uri = req.uri().path().to_string();
            let predict_called = predict_called_clone.clone();
            async move {
                if uri.contains("predict") {
                    // Custom-URL predictions use an OpenAI-style chat response;
                    // return a fixed completion wrapped in Zeta editable-region
                    // markers.
                    predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    Ok(gpui::http_client::Response::builder()
                        .body(
                            serde_json::to_string(&open_ai::Response {
                                id: "test-123".to_string(),
                                object: "chat.completion".to_string(),
                                created: 0,
                                model: "test".to_string(),
                                usage: open_ai::Usage {
                                    prompt_tokens: 0,
                                    completion_tokens: 0,
                                    total_tokens: 0,
                                },
                                choices: vec![open_ai::Choice {
                                    index: 0,
                                    message: open_ai::RequestMessage::Assistant {
                                        content: Some(open_ai::MessageContent::Plain(
                                            indoc! {"
                                                ```main.rs
                                                <|start_of_file|>
                                                <|editable_region_start|>
                                                fn main() {
                                                    println!(\"Hello, world!\");
                                                }
                                                <|editable_region_end|>
                                                ```
                                            "}
                                            .to_string(),
                                        )),
                                        tool_calls: vec![],
                                    },
                                    finish_reason: Some("stop".to_string()),
                                }],
                            })
                            .unwrap()
                            .into(),
                        )
                        .unwrap())
                } else {
                    // All other endpoints (including auth) reject the client.
                    Ok(gpui::http_client::Response::builder()
                        .status(401)
                        .body("Unauthorized".into())
                        .unwrap())
                }
            }
        }
    });

    let client =
        cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
    cx.update(|cx| {
        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
    });

    let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

    let buffer = project
        .update(cx, |project, cx| {
            let path = project
                .find_project_path(path!("/project/main.rs"), cx)
                .unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_buffer(&buffer, &project, cx)
    });
    cx.background_executor.run_until_parked();

    // Setting a custom predict-edits URL should bypass the authenticated
    // cloud path entirely.
    let completion_task = ep_store.update(cx, |ep_store, cx| {
        ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
        ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
        ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
    });

    // The prediction result itself is not under test — only that the custom
    // endpoint was reached despite authentication failing.
    let _ = completion_task.await;

    assert!(
        predict_called.load(std::sync::atomic::Ordering::SeqCst),
        "With custom URL, predict endpoint should be called even without authentication"
    );
}
2312
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
    // Verifies that two insertions produce a single unified-diff hunk with
    // three lines of surrounding context and correct line numbering.
    let buffer = cx.new(|cx| {
        Buffer::local(
            indoc! {"
                zero
                one
                two
                three
                four
                five
                six
                seven
                eight
                nine
                ten
                eleven
                twelve
                thirteen
                fourteen
                fifteen
                sixteen
                seventeen
                eighteen
                nineteen
                twenty
                twenty-one
                twenty-two
                twenty-three
                twenty-four
            "},
            cx,
        )
    });

    let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    // Apply the later insertion first so the earlier one's coordinates are
    // unaffected by the first edit.
    buffer.update(cx, |buffer, cx| {
        let point = Point::new(12, 0);
        buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
        let point = Point::new(8, 0);
        buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
    });

    let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());

    let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();

    // Both insertions are close enough to merge into one hunk: 10 old lines
    // become 12 new lines starting at line 6.
    assert_eq!(
        diff,
        indoc! {"
            @@ -6,10 +6,12 @@
             five
             six
             seven
            +FIRST INSERTION
             eight
             nine
             ten
             eleven
            +SECOND INSERTION
             twelve
             thirteen
             fourteen
            "}
    );
}
2380
/// Initializes test logging once for the whole test binary; `#[ctor::ctor]`
/// runs this before any tests execute.
#[ctor::ctor]
fn init_logger() {
    zlog::init_test();
}